diff --git a/bumble/hid.py b/bumble/hid.py index daf0d70..347c218 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -22,7 +22,7 @@ import enum import struct from pyee import EventEmitter -from typing import Optional, TYPE_CHECKING +from typing import Optional, Callable, TYPE_CHECKING from bumble import l2cap, device from bumble.colors import color @@ -193,8 +193,9 @@ class SendHandshakeMessage(Message): # ----------------------------------------------------------------------------- class HID(EventEmitter): - l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] - l2cap_intr_channel: Optional[l2cap.ClassicChannel] + l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None + l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None + connection: Optional[device.Connection] = None class Role(enum.IntEnum): HOST = 0x00 @@ -203,19 +204,29 @@ class HID(EventEmitter): def __init__(self, device: device.Device, role: Role) -> None: super().__init__() self.device = device - self.connection = None - self.remote_device_bd_address = None self.role = role - self.l2cap_ctrl_channel = None - self.l2cap_intr_channel = None - # Register ourselves with the L2CAP channel manager - device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection) - device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection) + device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection) + device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection) device.on('connection', self.on_device_connection) + def handle_get_report(self, pdu: bytes): + return + + def handle_set_report(self, pdu: bytes): + return + + def handle_get_protocol(self, pdu: bytes): + return + + def handle_set_protocol(self, pdu: bytes): + return + + def send_handshake_message(self, result_code: int): + return + async def connect_control_channel(self) -> None: # Create a new L2CAP connection - control channel try: @@ -259,20 +270,17 @@ class HID(EventEmitter): await channel.disconnect() def on_device_connection(self, connection: device.Connection) -> None: - self.connection = connection # type: ignore[assignment] - self.remote_device_bd_address = ( - connection.peer_address # type: ignore[assignment] - ) - connection.on('disconnection', self.on_disconnection) + self.connection = connection + connection.on('disconnection', self.on_device_disconnection) - def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: + def on_device_disconnection(self, reason: int) -> None: + self.connection = None + + def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel)) - def on_disconnection(self, reason: int) -> None: - self.connection = None - def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: if l2cap_channel.psm == HID_CONTROL_PSM: self.l2cap_ctrl_channel = l2cap_channel @@ -299,16 +307,16 @@ class HID(EventEmitter): self.emit('handshake', Message.Handshake(param)) elif message_type == Message.MessageType.GET_REPORT: logger.debug('<<< HID GET REPORT') - self.handle_get_report(pdu) # type: ignore[attr-defined] + self.handle_get_report(pdu) elif message_type == Message.MessageType.SET_REPORT: logger.debug('<<< HID SET REPORT') - self.handle_set_report(pdu) # type: ignore[attr-defined] + self.handle_set_report(pdu) elif message_type == Message.MessageType.GET_PROTOCOL: logger.debug('<<< HID GET PROTOCOL') - self.handle_get_protocol(pdu) # type: ignore[attr-defined] + self.handle_get_protocol(pdu) elif message_type == Message.MessageType.SET_PROTOCOL: logger.debug('<<< HID SET PROTOCOL') - self.handle_set_protocol(pdu) # type: ignore[attr-defined] + self.handle_set_protocol(pdu) elif message_type == Message.MessageType.DATA: logger.debug('<<< HID CONTROL DATA') self.emit('control_data', pdu) @@ -326,9 +334,7 @@ class HID(EventEmitter): logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') else: logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') - self.send_handshake_message( # type: ignore[attr-defined] - Message.Handshake.ERR_UNSUPPORTED_REQUEST - ) + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def on_intr_pdu(self, pdu: bytes) -> None: logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') @@ -379,10 +385,10 @@ class Device(HID): def __init__(self, device: device.Device) -> None: super().__init__(device, HID.Role.DEVICE) - self.get_report_cb = None - self.set_report_cb = None - self.get_protocol_cb = None - self.set_protocol_cb = None + get_report_cb: Optional[Callable[[int, int, int], None]] = None + set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None + get_protocol_cb: Optional[Callable[[], None]] = None + set_protocol_cb: Optional[Callable[[int, bytes], None]] = None def send_handshake_message(self, result_code: int) -> None: msg = SendHandshakeMessage(result_code) @@ -418,7 +424,7 @@ class Device(HID): data = bytearray() data.append(report_id) data.extend(ret.data) - if len(data) < self.l2cap_ctrl_channel.mtu: + if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr] self.send_control_data(report_type=report_type, data=data) else: self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) @@ -441,10 +447,8 @@ class Device(HID): report_type = pdu[0] & 0x03 report_id = pdu[1] report_data = pdu[2:] - report_size = len(pdu[1:]) - ret = self.set_report_cb( - report_id, report_type, report_size, report_data - ) # type: ignore + report_size = len(report_data) + 1 + ret = self.set_report_cb(report_id, report_type, report_size, report_data) if ret.status == self.GetSetReturn.SUCCESS: self.send_handshake_message(Message.Handshake.SUCCESSFUL) elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: diff --git a/examples/run_hid_device.py b/examples/run_hid_device.py index e0ac281..0a5b1d0 100644 --- a/examples/run_hid_device.py +++ b/examples/run_hid_device.py @@ -428,63 +428,62 @@ class DeviceData: deviceData = DeviceData() # ----------------------------------------------------------------------------- -async def keyboard_device(hid_device, command): +async def keyboard_device(hid_device): - if command == 'web': - # Start a Websocket server to receive events from a web page - async def serve(websocket, _path): - global deviceData - while True: - try: - message = await websocket.recv() - print('Received: ', str(message)) - parsed = json.loads(message) - message_type = parsed['type'] - if message_type == 'keydown': - # Only deal with keys a to z for now - key = parsed['key'] - if len(key) == 1: - code = ord(key) - if ord('a') <= code <= ord('z'): - hid_code = 0x04 + code - ord('a') - deviceData.keyboardData = bytearray( - [ - 0x01, - 0x00, - 0x00, - hid_code, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - hid_device.send_data(deviceData.keyboardData) - elif message_type == 'keyup': - deviceData.keyboardData = bytearray( - [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - hid_device.send_data(deviceData.keyboardData) - elif message_type == "mousemove": - # logical min and max values - log_min = -127 - log_max = 127 - x = parsed['x'] - y = parsed['y'] - # limiting x and y values within logical max and min range - x = max(log_min, min(log_max, x)) - y = max(log_min, min(log_max, y)) - x_cord = x.to_bytes(signed=True) - y_cord = y.to_bytes(signed=True) - deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord - hid_device.send_data(deviceData.mouseData) - except websockets.exceptions.ConnectionClosedOK: - pass + # Start a Websocket server to receive events from a web page + async def serve(websocket, _path): + global deviceData + while True: + try: + message = await websocket.recv() + print('Received: ', str(message)) + parsed = json.loads(message) + message_type = parsed['type'] + if message_type == 'keydown': + # Only deal with keys a to z for now + key = parsed['key'] + if len(key) == 1: + code = ord(key) + if ord('a') <= code <= ord('z'): + hid_code = 0x04 + code - ord('a') + deviceData.keyboardData = bytearray( + [ + 0x01, + 0x00, + 0x00, + hid_code, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + ] + ) + hid_device.send_data(deviceData.keyboardData) + elif message_type == 'keyup': + deviceData.keyboardData = bytearray( + [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] + ) + hid_device.send_data(deviceData.keyboardData) + elif message_type == "mousemove": + # logical min and max values + log_min = -127 + log_max = 127 + x = parsed['x'] + y = parsed['y'] + # limiting x and y values within logical max and min range + x = max(log_min, min(log_max, x)) + y = max(log_min, min(log_max, y)) + x_cord = x.to_bytes(signed=True) + y_cord = y.to_bytes(signed=True) + deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord + hid_device.send_data(deviceData.mouseData) + except websockets.exceptions.ConnectionClosedOK: + pass - # pylint: disable-next=no-member - await websockets.serve(serve, 'localhost', 8989) - await asyncio.get_event_loop().create_future() + # pylint: disable-next=no-member + await websockets.serve(serve, 'localhost', 8989) + await asyncio.get_event_loop().create_future() # ----------------------------------------------------------------------------- @@ -511,10 +510,10 @@ async def main(): if connection is not None: await connection.disconnect() - def on_hid_data_cb(pdu): + def on_hid_data_cb(pdu: bytes): print(f'Received Data, PDU: {pdu.hex()}') - def on_get_report_cb(report_id, report_type, buffer_size): + def on_get_report_cb(report_id: int, report_type: int, buffer_size: int): retValue = hid_device.GetSetStatus() print( "GET_REPORT report_id: " @@ -552,7 +551,9 @@ async def main(): return retValue - def on_set_report_cb(report_id, report_type, report_size, data): + def on_set_report_cb( + report_id: int, report_type: int, report_size: int, data: bytes + ): retValue = hid_device.GetSetStatus() print( "SET_REPORT report_id: " @@ -586,7 +587,7 @@ async def main(): retValue.status = hid_device.GetSetReturn.SUCCESS return retValue - def on_set_protocol_cb(protocol): + def on_set_protocol_cb(protocol: int): retValue = hid_device.GetSetStatus() # We do not support SET_PROTOCOL. print("SET_PROTOCOL report_id: " + str(protocol)) @@ -731,24 +732,14 @@ async def main(): else: print("Invalid option selected.") - if len(sys.argv) > 3: + if (len(sys.argv) > 3) and (command == 'test-mode'): + # Test mode for PTS/Unit testing command = sys.argv[3] - - if command == 'test-mode': - # Enabling menu for testing - await menu() - - elif command == 'web': - # Run as a keyboard and mouse device - await keyboard_device(hid_device, command) - - else: - print("Command incorrect. Switching to default: web") - await keyboard_device(hid_device, 'web') - + await menu() else: # default option is using keyboard.html (web) - await keyboard_device(hid_device, 'web') + print("Command incorrect. Switching to default") + await keyboard_device(hid_device) await hci_source.wait_for_termination()