diff --git a/bumble/hid.py b/bumble/hid.py index 0dd87cd..daf0d70 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -19,17 +19,15 @@ from __future__ import annotations from dataclasses import dataclass import logging import enum +import struct from pyee import EventEmitter from typing import Optional, TYPE_CHECKING -from bumble import l2cap +from bumble import l2cap, device from bumble.colors import color from bumble.core import InvalidStateError, ProtocolError -if TYPE_CHECKING: - from bumble.device import Device, Connection - # ----------------------------------------------------------------------------- # Logging @@ -105,10 +103,11 @@ class GetReportMessage(Message): if self.buffer_size == 0: return self.header(self.report_type) + packet_bytes else: - packet_bytes.extend( - [(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)] + return ( + self.header(0x08 | self.report_type) + + packet_bytes + + struct.pack(" bytes: - packet_bytes = bytearray() - - packet_bytes.extend(self.data) - return self.header(self.report_type) + packet_bytes + return self.header(self.report_type) + self.data @dataclass @@ -151,17 +147,6 @@ class SetProtocolMessage(Message): return self.header(self.protocol_mode) -@dataclass -class GetProtocolReplyMessage(Message): - protocol_mode: int - message_type = Message.MessageType.DATA - - def __bytes__(self) -> bytes: - packet_bytes = bytearray() - packet_bytes.append(self.protocol_mode) - return self.header(Message.ReportType.OTHER_REPORT) + packet_bytes - - @dataclass class Suspend(Message): message_type = Message.MessageType.CONTROL @@ -215,7 +200,7 @@ class HID(EventEmitter): HOST = 0x00 DEVICE = 0x01 - def __init__(self, device: Device, role: int) -> None: + def __init__(self, device: device.Device, role: Role) -> None: super().__init__() self.device = device self.connection = None @@ -273,11 +258,11 @@ class HID(EventEmitter): self.l2cap_ctrl_channel = None await channel.disconnect() - def on_device_connection(self, connection: Connection) -> None: + def on_device_connection(self, connection: device.Connection) -> None: self.connection = connection # type: ignore[assignment] self.remote_device_bd_address = ( - connection.peer_address - ) # type: ignore[assignment] + connection.peer_address # type: ignore[assignment] + ) connection.on('disconnection', self.on_disconnection) def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: @@ -341,9 +326,9 @@ class HID(EventEmitter): logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') else: logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') - self.send_handshake_message( + self.send_handshake_message( # type: ignore[attr-defined] Message.Handshake.ERR_UNSUPPORTED_REQUEST - ) # type: ignore[attr-defined] + ) def on_intr_pdu(self, pdu: bytes) -> None: logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') @@ -378,7 +363,7 @@ class HID(EventEmitter): # ----------------------------------------------------------------------------- -class Device(HID): # type: ignore[no-redef] +class Device(HID): class GetSetReturn(enum.IntEnum): FAILURE = 0x00 REPORT_ID_NOT_FOUND = 0x01 @@ -389,10 +374,10 @@ class Device(HID): # type: ignore[no-redef] class GetSetStatus: def __init__(self) -> None: - self.data: bytes + self.data = bytearray() self.status = 0 - def __init__(self, device: Device) -> None: + def __init__(self, device: device.Device) -> None: super().__init__(device, HID.Role.DEVICE) self.get_report_cb = None self.set_report_cb = None @@ -412,7 +397,10 @@ class Device(HID): # type: ignore[no-redef] self.send_pdu_on_ctrl(hid_message) def handle_get_report(self, pdu: bytes): - ret = self.GetSetStatus() + if self.get_report_cb is None: + logger.debug("GetReport callback not registered !!") + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + return report_type = pdu[0] & 0x03 buffer_flag = (pdu[0] & 0x08) >> 3 report_id = pdu[1] @@ -422,32 +410,23 @@ class Device(HID): # type: ignore[no-redef] else: buffer_size = 0 - if self.get_report_cb != None: - ret = self.get_report_cb( - report_id, report_type, buffer_size - ) # type: ignore + ret = self.get_report_cb(report_id, report_type, buffer_size) - if ret.status == self.GetSetReturn.FAILURE: - self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) - elif ret.status == self.GetSetReturn.SUCCESS: - data = bytearray() - data.append(report_id) - data.extend(ret.data) - if ( - len(data) < self.l2cap_ctrl_channel.mtu - ): # type: ignore[union-attr] - self.send_control_data(report_type=report_type, data=data) - else: - self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - - elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: - self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) - elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: + if ret.status == self.GetSetReturn.FAILURE: + self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) + elif ret.status == self.GetSetReturn.SUCCESS: + data = bytearray() + data.append(report_id) + data.extend(ret.data) + if len(data) < self.l2cap_ctrl_channel.mtu: + self.send_control_data(report_type=report_type, data=data) + else: self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - else: - logger.debug("GetReport callback not registered !!") + elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: + self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) + elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: + self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) + elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def register_get_report_cb(self, cb): @@ -455,56 +434,55 @@ class Device(HID): # type: ignore[no-redef] logger.debug("GetReport callback registered successfully") def handle_set_report(self, pdu: bytes): - if self.set_report_cb != None: - report_type = pdu[0] & 0x03 - report_id = pdu[1] - report_data = pdu[2:] - report_size = len(pdu[1:]) - ret = self.set_report_cb( - report_id, report_type, report_size, report_data - ) # type: ignore - if ret.status == self.GetSetReturn.SUCCESS: - self.send_handshake_message(Message.Handshake.SUCCESSFUL) - elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: - self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: - self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) - else: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - else: + if self.set_report_cb is None: logger.debug("SetReport callback not registered !!") self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + return + report_type = pdu[0] & 0x03 + report_id = pdu[1] + report_data = pdu[2:] + report_size = len(pdu[1:]) + ret = self.set_report_cb( + report_id, report_type, report_size, report_data + ) # type: ignore + if ret.status == self.GetSetReturn.SUCCESS: + self.send_handshake_message(Message.Handshake.SUCCESSFUL) + elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: + self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) + elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: + self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) + else: + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def register_set_report_cb(self, cb): self.set_report_cb = cb logger.debug("SetReport callback registered successfully") def handle_get_protocol(self, pdu: bytes): - ret = self.GetSetStatus() - if self.get_protocol_cb != None: - ret = self.get_protocol_cb() # type: ignore - if ret.status == self.GetSetReturn.SUCCESS: - self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) - return - else: + if self.get_protocol_cb is None: logger.debug("GetProtocol callback not registered !!") - - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + return + ret = self.get_protocol_cb() + if ret.status == self.GetSetReturn.SUCCESS: + self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) + else: + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def register_get_protocol_cb(self, cb): self.get_protocol_cb = cb logger.debug("GetProtocol callback registered successfully") def handle_set_protocol(self, pdu: bytes): - ret = self.GetSetStatus() - if self.set_protocol_cb != None: - ret = self.set_protocol_cb(pdu[0] & 0x01) # type: ignore - if ret.status == self.GetSetReturn.SUCCESS: - return - else: + if self.set_protocol_cb is None: logger.debug("SetProtocol callback not registered !!") - - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + return + ret = self.set_protocol_cb(pdu[0] & 0x01) + if ret.status == self.GetSetReturn.SUCCESS: + self.send_handshake_message(Message.Handshake.SUCCESSFUL) + else: + self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) def register_set_protocol_cb(self, cb): self.set_protocol_cb = cb @@ -513,7 +491,7 @@ class Device(HID): # type: ignore[no-redef] # ----------------------------------------------------------------------------- class Host(HID): - def __init__(self, device: Device) -> None: + def __init__(self, device: device.Device) -> None: super().__init__(device, HID.Role.HOST) def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: