From 2248f9ae5edc714bacbb3aa86c99a3f034ef5c84 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Tue, 13 Aug 2024 22:52:00 +0800 Subject: [PATCH] Correct HID type annotations --- bumble/hid.py | 52 +++++++++++++++++------------------- examples/run_hid_device.py | 54 ++++++++++++++++++-------------------- 2 files changed, 49 insertions(+), 57 deletions(-) diff --git a/bumble/hid.py b/bumble/hid.py index 1b4aa003..d4a2a721 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -23,13 +23,12 @@ import struct from abc import ABC, abstractmethod from pyee import EventEmitter -from typing import Optional, Callable, TYPE_CHECKING +from typing import Optional, Callable from typing_extensions import override from bumble import l2cap, device -from bumble.colors import color from bumble.core import InvalidStateError, ProtocolError -from .hci import Address +from bumble.hci import Address # ----------------------------------------------------------------------------- @@ -220,31 +219,27 @@ class HID(ABC, EventEmitter): async def connect_control_channel(self) -> None: # Create a new L2CAP connection - control channel try: - self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( + channel = await self.device.l2cap_channel_manager.connect( self.connection, HID_CONTROL_PSM ) + channel.sink = self.on_ctrl_pdu + self.l2cap_ctrl_channel = channel except ProtocolError: logging.exception(f'L2CAP connection failed.') raise - assert self.l2cap_ctrl_channel is not None - # Become a sink for the L2CAP channel - self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu - async def connect_interrupt_channel(self) -> None: # Create a new L2CAP connection - interrupt channel try: - self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( + channel = await self.device.l2cap_channel_manager.connect( self.connection, HID_INTERRUPT_PSM ) + channel.sink = self.on_intr_pdu + self.l2cap_intr_channel = channel except ProtocolError: logging.exception(f'L2CAP connection failed.') raise - assert self.l2cap_intr_channel is not None - # Become a sink for the L2CAP channel - self.l2cap_intr_channel.sink = self.on_intr_pdu - async def disconnect_interrupt_channel(self) -> None: if self.l2cap_intr_channel is None: raise InvalidStateError('invalid state') @@ -334,17 +329,18 @@ class Device(HID): ERR_INVALID_PARAMETER = 0x04 SUCCESS = 0xFF + @dataclass class GetSetStatus: - def __init__(self) -> None: - self.data = bytearray() - self.status = 0 + data: bytes = b'' + status: int = 0 + + get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None + set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None + get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None + set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None def __init__(self, device: device.Device) -> None: super().__init__(device, HID.Role.DEVICE) - get_report_cb: Optional[Callable[[int, int, int], None]] = None - set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None - get_protocol_cb: Optional[Callable[[], None]] = None - set_protocol_cb: Optional[Callable[[int], None]] = None @override def on_ctrl_pdu(self, pdu: bytes) -> None: @@ -410,7 +406,6 @@ class Device(HID): buffer_size = 0 ret = self.get_report_cb(report_id, report_type, buffer_size) - assert ret is not None if ret.status == self.GetSetReturn.FAILURE: self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) elif ret.status == self.GetSetReturn.SUCCESS: @@ -428,7 +423,9 @@ class Device(HID): elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None: + def register_get_report_cb( + self, cb: Callable[[int, int, int], Device.GetSetStatus] + ) -> None: self.get_report_cb = cb logger.debug("GetReport callback registered successfully") @@ -442,7 +439,6 @@ class Device(HID): report_data = pdu[2:] report_size = len(report_data) + 1 ret = self.set_report_cb(report_id, report_type, report_size, report_data) - assert ret is not None if ret.status == self.GetSetReturn.SUCCESS: self.send_handshake_message(Message.Handshake.SUCCESSFUL) elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: @@ -453,7 +449,7 @@ class Device(HID): self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def register_set_report_cb( - self, cb: Callable[[int, int, int, bytes], None] + self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus] ) -> None: self.set_report_cb = cb logger.debug("SetReport callback registered successfully") @@ -464,13 +460,12 @@ class Device(HID): self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) return ret = self.get_protocol_cb() - assert ret is not None if ret.status == self.GetSetReturn.SUCCESS: self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) else: self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - def register_get_protocol_cb(self, cb: Callable[[], None]) -> None: + def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None: self.get_protocol_cb = cb logger.debug("GetProtocol callback registered successfully") @@ -480,13 +475,14 @@ class Device(HID): self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) return ret = self.set_protocol_cb(pdu[0] & 0x01) - assert ret is not None if ret.status == self.GetSetReturn.SUCCESS: self.send_handshake_message(Message.Handshake.SUCCESSFUL) else: self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None: + def register_set_protocol_cb( + self, cb: Callable[[int], Device.GetSetStatus] + ) -> None: self.set_protocol_cb = cb logger.debug("SetProtocol callback registered successfully") diff --git a/examples/run_hid_device.py b/examples/run_hid_device.py index 2287be09..160e3952 100644 --- a/examples/run_hid_device.py +++ b/examples/run_hid_device.py @@ -21,7 +21,7 @@ import os import logging import json import websockets -from bumble.colors import color +import struct from bumble.device import Device from bumble.transport import open_transport_or_link @@ -30,9 +30,7 @@ from bumble.core import ( BT_L2CAP_PROTOCOL_ID, BT_HUMAN_INTERFACE_DEVICE_SERVICE, BT_HIDP_PROTOCOL_ID, - UUID, ) -from bumble.hci import Address from bumble.hid import ( Device as HID_Device, HID_CONTROL_PSM, @@ -40,20 +38,17 @@ from bumble.hid import ( Message, ) from bumble.sdp import ( - Client as SDP_Client, DataElement, ServiceAttribute, SDP_PUBLIC_BROWSE_ROOT, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_ALL_ATTRIBUTES_RANGE, SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, ) -from bumble.utils import AsyncRunner # ----------------------------------------------------------------------------- # SDP attributes for Bluetooth HID devices @@ -430,7 +425,7 @@ deviceData = DeviceData() # ----------------------------------------------------------------------------- -async def keyboard_device(hid_device): +async def keyboard_device(hid_device: HID_Device): # Start a Websocket server to receive events from a web page async def serve(websocket, _path): @@ -476,9 +471,9 @@ async def keyboard_device(hid_device): # limiting x and y values within logical max and min range x = max(log_min, min(log_max, x)) y = max(log_min, min(log_max, y)) - x_cord = x.to_bytes(signed=True) - y_cord = y.to_bytes(signed=True) - deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord + deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack( + ">bb", x, y + ) hid_device.send_data(deviceData.mouseData) except websockets.exceptions.ConnectionClosedOK: pass @@ -515,7 +510,9 @@ async def main() -> None: def on_hid_data_cb(pdu: bytes): print(f'Received Data, PDU: {pdu.hex()}') - def on_get_report_cb(report_id: int, report_type: int, buffer_size: int): + def on_get_report_cb( + report_id: int, report_type: int, buffer_size: int + ) -> HID_Device.GetSetStatus: retValue = hid_device.GetSetStatus() print( "GET_REPORT report_id: " @@ -555,8 +552,7 @@ async def main() -> None: def on_set_report_cb( report_id: int, report_type: int, report_size: int, data: bytes - ): - retValue = hid_device.GetSetStatus() + ) -> HID_Device.GetSetStatus: print( "SET_REPORT report_id: " + str(report_id) @@ -568,33 +564,33 @@ async def main() -> None: + str(data) ) if report_type == Message.ReportType.FEATURE_REPORT: - retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER + status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER elif report_type == Message.ReportType.INPUT_REPORT: if report_id == 1 and report_size != len(deviceData.keyboardData): - retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER + status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER elif report_id == 2 and report_size != len(deviceData.mouseData): - retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER + status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER elif report_id == 3: - retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND + status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND else: - retValue.status = hid_device.GetSetReturn.SUCCESS + status = HID_Device.GetSetReturn.SUCCESS else: - retValue.status = hid_device.GetSetReturn.SUCCESS + status = HID_Device.GetSetReturn.SUCCESS - return retValue + return HID_Device.GetSetStatus(status=status) - def on_get_protocol_cb(): - retValue = hid_device.GetSetStatus() - retValue.data = protocol_mode.to_bytes() - retValue.status = hid_device.GetSetReturn.SUCCESS - return retValue + def on_get_protocol_cb() -> HID_Device.GetSetStatus: + return HID_Device.GetSetStatus( + data=bytes([protocol_mode]), + status=hid_device.GetSetReturn.SUCCESS, + ) - def on_set_protocol_cb(protocol: int): - retValue = hid_device.GetSetStatus() + def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus: # We do not support SET_PROTOCOL. print(f"SET_PROTOCOL report_id: {protocol}") - retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST - return retValue + return HID_Device.GetSetStatus( + status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST + ) def on_virtual_cable_unplug_cb(): print('Received Virtual Cable Unplug')