mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
Correct HID type annotations
This commit is contained in:
@@ -23,13 +23,12 @@ import struct
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
from typing import Optional, Callable, TYPE_CHECKING
|
from typing import Optional, Callable
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from bumble import l2cap, device
|
from bumble import l2cap, device
|
||||||
from bumble.colors import color
|
|
||||||
from bumble.core import InvalidStateError, ProtocolError
|
from bumble.core import InvalidStateError, ProtocolError
|
||||||
from .hci import Address
|
from bumble.hci import Address
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
|
|||||||
async def connect_control_channel(self) -> None:
|
async def connect_control_channel(self) -> None:
|
||||||
# Create a new L2CAP connection - control channel
|
# Create a new L2CAP connection - control channel
|
||||||
try:
|
try:
|
||||||
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
|
channel = await self.device.l2cap_channel_manager.connect(
|
||||||
self.connection, HID_CONTROL_PSM
|
self.connection, HID_CONTROL_PSM
|
||||||
)
|
)
|
||||||
|
channel.sink = self.on_ctrl_pdu
|
||||||
|
self.l2cap_ctrl_channel = channel
|
||||||
except ProtocolError:
|
except ProtocolError:
|
||||||
logging.exception(f'L2CAP connection failed.')
|
logging.exception(f'L2CAP connection failed.')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
assert self.l2cap_ctrl_channel is not None
|
|
||||||
# Become a sink for the L2CAP channel
|
|
||||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
|
||||||
|
|
||||||
async def connect_interrupt_channel(self) -> None:
|
async def connect_interrupt_channel(self) -> None:
|
||||||
# Create a new L2CAP connection - interrupt channel
|
# Create a new L2CAP connection - interrupt channel
|
||||||
try:
|
try:
|
||||||
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
|
channel = await self.device.l2cap_channel_manager.connect(
|
||||||
self.connection, HID_INTERRUPT_PSM
|
self.connection, HID_INTERRUPT_PSM
|
||||||
)
|
)
|
||||||
|
channel.sink = self.on_intr_pdu
|
||||||
|
self.l2cap_intr_channel = channel
|
||||||
except ProtocolError:
|
except ProtocolError:
|
||||||
logging.exception(f'L2CAP connection failed.')
|
logging.exception(f'L2CAP connection failed.')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
assert self.l2cap_intr_channel is not None
|
|
||||||
# Become a sink for the L2CAP channel
|
|
||||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
|
||||||
|
|
||||||
async def disconnect_interrupt_channel(self) -> None:
|
async def disconnect_interrupt_channel(self) -> None:
|
||||||
if self.l2cap_intr_channel is None:
|
if self.l2cap_intr_channel is None:
|
||||||
raise InvalidStateError('invalid state')
|
raise InvalidStateError('invalid state')
|
||||||
@@ -334,17 +329,18 @@ class Device(HID):
|
|||||||
ERR_INVALID_PARAMETER = 0x04
|
ERR_INVALID_PARAMETER = 0x04
|
||||||
SUCCESS = 0xFF
|
SUCCESS = 0xFF
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class GetSetStatus:
|
class GetSetStatus:
|
||||||
def __init__(self) -> None:
|
data: bytes = b''
|
||||||
self.data = bytearray()
|
status: int = 0
|
||||||
self.status = 0
|
|
||||||
|
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
|
||||||
|
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
|
||||||
|
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
|
||||||
|
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
|
||||||
|
|
||||||
def __init__(self, device: device.Device) -> None:
|
def __init__(self, device: device.Device) -> None:
|
||||||
super().__init__(device, HID.Role.DEVICE)
|
super().__init__(device, HID.Role.DEVICE)
|
||||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
|
||||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
|
||||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
|
||||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||||
@@ -410,7 +406,6 @@ class Device(HID):
|
|||||||
buffer_size = 0
|
buffer_size = 0
|
||||||
|
|
||||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.FAILURE:
|
if ret.status == self.GetSetReturn.FAILURE:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||||
@@ -428,7 +423,9 @@ class Device(HID):
|
|||||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
def register_get_report_cb(
|
||||||
|
self, cb: Callable[[int, int, int], Device.GetSetStatus]
|
||||||
|
) -> None:
|
||||||
self.get_report_cb = cb
|
self.get_report_cb = cb
|
||||||
logger.debug("GetReport callback registered successfully")
|
logger.debug("GetReport callback registered successfully")
|
||||||
|
|
||||||
@@ -442,7 +439,6 @@ class Device(HID):
|
|||||||
report_data = pdu[2:]
|
report_data = pdu[2:]
|
||||||
report_size = len(report_data) + 1
|
report_size = len(report_data) + 1
|
||||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||||
@@ -453,7 +449,7 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_report_cb(
|
def register_set_report_cb(
|
||||||
self, cb: Callable[[int, int, int, bytes], None]
|
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.set_report_cb = cb
|
self.set_report_cb = cb
|
||||||
logger.debug("SetReport callback registered successfully")
|
logger.debug("SetReport callback registered successfully")
|
||||||
@@ -464,13 +460,12 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
return
|
return
|
||||||
ret = self.get_protocol_cb()
|
ret = self.get_protocol_cb()
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||||
else:
|
else:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
|
||||||
self.get_protocol_cb = cb
|
self.get_protocol_cb = cb
|
||||||
logger.debug("GetProtocol callback registered successfully")
|
logger.debug("GetProtocol callback registered successfully")
|
||||||
|
|
||||||
@@ -480,13 +475,14 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
return
|
return
|
||||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
else:
|
else:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
def register_set_protocol_cb(
|
||||||
|
self, cb: Callable[[int], Device.GetSetStatus]
|
||||||
|
) -> None:
|
||||||
self.set_protocol_cb = cb
|
self.set_protocol_cb = cb
|
||||||
logger.debug("SetProtocol callback registered successfully")
|
logger.debug("SetProtocol callback registered successfully")
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import websockets
|
import websockets
|
||||||
from bumble.colors import color
|
import struct
|
||||||
|
|
||||||
from bumble.device import Device
|
from bumble.device import Device
|
||||||
from bumble.transport import open_transport_or_link
|
from bumble.transport import open_transport_or_link
|
||||||
@@ -30,9 +30,7 @@ from bumble.core import (
|
|||||||
BT_L2CAP_PROTOCOL_ID,
|
BT_L2CAP_PROTOCOL_ID,
|
||||||
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
||||||
BT_HIDP_PROTOCOL_ID,
|
BT_HIDP_PROTOCOL_ID,
|
||||||
UUID,
|
|
||||||
)
|
)
|
||||||
from bumble.hci import Address
|
|
||||||
from bumble.hid import (
|
from bumble.hid import (
|
||||||
Device as HID_Device,
|
Device as HID_Device,
|
||||||
HID_CONTROL_PSM,
|
HID_CONTROL_PSM,
|
||||||
@@ -40,20 +38,17 @@ from bumble.hid import (
|
|||||||
Message,
|
Message,
|
||||||
)
|
)
|
||||||
from bumble.sdp import (
|
from bumble.sdp import (
|
||||||
Client as SDP_Client,
|
|
||||||
DataElement,
|
DataElement,
|
||||||
ServiceAttribute,
|
ServiceAttribute,
|
||||||
SDP_PUBLIC_BROWSE_ROOT,
|
SDP_PUBLIC_BROWSE_ROOT,
|
||||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_ALL_ATTRIBUTES_RANGE,
|
|
||||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
)
|
)
|
||||||
from bumble.utils import AsyncRunner
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# SDP attributes for Bluetooth HID devices
|
# SDP attributes for Bluetooth HID devices
|
||||||
@@ -430,7 +425,7 @@ deviceData = DeviceData()
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def keyboard_device(hid_device):
|
async def keyboard_device(hid_device: HID_Device):
|
||||||
|
|
||||||
# Start a Websocket server to receive events from a web page
|
# Start a Websocket server to receive events from a web page
|
||||||
async def serve(websocket, _path):
|
async def serve(websocket, _path):
|
||||||
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
|
|||||||
# limiting x and y values within logical max and min range
|
# limiting x and y values within logical max and min range
|
||||||
x = max(log_min, min(log_max, x))
|
x = max(log_min, min(log_max, x))
|
||||||
y = max(log_min, min(log_max, y))
|
y = max(log_min, min(log_max, y))
|
||||||
x_cord = x.to_bytes(signed=True)
|
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
|
||||||
y_cord = y.to_bytes(signed=True)
|
">bb", x, y
|
||||||
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
|
)
|
||||||
hid_device.send_data(deviceData.mouseData)
|
hid_device.send_data(deviceData.mouseData)
|
||||||
except websockets.exceptions.ConnectionClosedOK:
|
except websockets.exceptions.ConnectionClosedOK:
|
||||||
pass
|
pass
|
||||||
@@ -515,7 +510,9 @@ async def main() -> None:
|
|||||||
def on_hid_data_cb(pdu: bytes):
|
def on_hid_data_cb(pdu: bytes):
|
||||||
print(f'Received Data, PDU: {pdu.hex()}')
|
print(f'Received Data, PDU: {pdu.hex()}')
|
||||||
|
|
||||||
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
|
def on_get_report_cb(
|
||||||
|
report_id: int, report_type: int, buffer_size: int
|
||||||
|
) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
retValue = hid_device.GetSetStatus()
|
||||||
print(
|
print(
|
||||||
"GET_REPORT report_id: "
|
"GET_REPORT report_id: "
|
||||||
@@ -555,8 +552,7 @@ async def main() -> None:
|
|||||||
|
|
||||||
def on_set_report_cb(
|
def on_set_report_cb(
|
||||||
report_id: int, report_type: int, report_size: int, data: bytes
|
report_id: int, report_type: int, report_size: int, data: bytes
|
||||||
):
|
) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
|
||||||
print(
|
print(
|
||||||
"SET_REPORT report_id: "
|
"SET_REPORT report_id: "
|
||||||
+ str(report_id)
|
+ str(report_id)
|
||||||
@@ -568,33 +564,33 @@ async def main() -> None:
|
|||||||
+ str(data)
|
+ str(data)
|
||||||
)
|
)
|
||||||
if report_type == Message.ReportType.FEATURE_REPORT:
|
if report_type == Message.ReportType.FEATURE_REPORT:
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_type == Message.ReportType.INPUT_REPORT:
|
elif report_type == Message.ReportType.INPUT_REPORT:
|
||||||
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_id == 3:
|
elif report_id == 3:
|
||||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||||
else:
|
else:
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status = HID_Device.GetSetReturn.SUCCESS
|
||||||
else:
|
else:
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status = HID_Device.GetSetReturn.SUCCESS
|
||||||
|
|
||||||
return retValue
|
return HID_Device.GetSetStatus(status=status)
|
||||||
|
|
||||||
def on_get_protocol_cb():
|
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
return HID_Device.GetSetStatus(
|
||||||
retValue.data = protocol_mode.to_bytes()
|
data=bytes([protocol_mode]),
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status=hid_device.GetSetReturn.SUCCESS,
|
||||||
return retValue
|
)
|
||||||
|
|
||||||
def on_set_protocol_cb(protocol: int):
|
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
|
||||||
# We do not support SET_PROTOCOL.
|
# We do not support SET_PROTOCOL.
|
||||||
print(f"SET_PROTOCOL report_id: {protocol}")
|
print(f"SET_PROTOCOL report_id: {protocol}")
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
return HID_Device.GetSetStatus(
|
||||||
return retValue
|
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
def on_virtual_cable_unplug_cb():
|
def on_virtual_cable_unplug_cb():
|
||||||
print('Received Virtual Cable Unplug')
|
print('Received Virtual Cable Unplug')
|
||||||
|
|||||||
Reference in New Issue
Block a user