From 3ab64ce00ddfe806613e73a159ff808e389e90a7 Mon Sep 17 00:00:00 2001 From: skarnataki Date: Mon, 16 Oct 2023 07:13:30 +0000 Subject: [PATCH] Fixed lint and pre-commit errors. --- bumble/hid.py | 178 ++++++++++++++++++++-------------- bumble/sdp.py | 2 +- examples/hid_report_parser.py | 4 +- 3 files changed, 108 insertions(+), 76 deletions(-) diff --git a/bumble/hid.py b/bumble/hid.py index 1a8e5f9..e4d6a77 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -41,69 +41,71 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -# fmt: off -HID_CONTROL_PSM = 0x0011 -HID_INTERRUPT_PSM = 0x0013 +# fmt: on +HID_CONTROL_PSM = 0x0011 +HID_INTERRUPT_PSM = 0x0013 -class Message(): +class Message: message_type: MessageType # Report types class ReportType(enum.IntEnum): - OTHER_REPORT = 0x00 - INPUT_REPORT = 0x01 - OUTPUT_REPORT = 0x02 - FEATURE_REPORT = 0x03 + OTHER_REPORT = 0x00 + INPUT_REPORT = 0x01 + OUTPUT_REPORT = 0x02 + FEATURE_REPORT = 0x03 # Handshake parameters class Handshake(enum.IntEnum): - SUCCESSFUL = 0x00 - NOT_READY = 0x01 - ERR_INVALID_REPORT_ID = 0x02 + SUCCESSFUL = 0x00 + NOT_READY = 0x01 + ERR_INVALID_REPORT_ID = 0x02 ERR_UNSUPPORTED_REQUEST = 0x03 - ERR_UNKNOWN = 0x0E - ERR_FATAL = 0x0F + ERR_UNKNOWN = 0x0E + ERR_FATAL = 0x0F - #Message Type + # Message Type class MessageType(enum.IntEnum): - HANDSHAKE = 0x00 - CONTROL = 0x01 - GET_REPORT = 0x04 - SET_REPORT = 0x05 - GET_PROTOCOL = 0x06 - SET_PROTOCOL = 0x07 - DATA = 0x0A + HANDSHAKE = 0x00 + CONTROL = 0x01 + GET_REPORT = 0x04 + SET_REPORT = 0x05 + GET_PROTOCOL = 0x06 + SET_PROTOCOL = 0x07 + DATA = 0x0A # Protocol modes class ProtocolMode(enum.IntEnum): - BOOT_PROTOCOL = 0x00 - REPORT_PROTOCOL = 0x01 + BOOT_PROTOCOL = 0x00 + REPORT_PROTOCOL = 0x01 # Control Operations class ControlCommand(enum.IntEnum): - SUSPEND = 0x03 - EXIT_SUSPEND = 0x04 + SUSPEND = 0x03 + EXIT_SUSPEND = 0x04 VIRTUAL_CABLE_UNPLUG = 0x05 # Class Method to derive header @classmethod - def header( cls , lower_bits : int = 0x00 ) -> bytes : + def header(cls, lower_bits: int = 0x00) -> bytes: return bytes([(cls.message_type << 4) | lower_bits]) # HIDP messages @dataclass class GetReportMessage(Message): - report_type : int - report_id : int - buffer_size : int + report_type: int + report_id: int + buffer_size: int message_type = Message.MessageType.GET_REPORT def __bytes__(self) -> bytes: packet_bytes = bytearray() packet_bytes.append(self.report_id) - packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)]) - if(self.report_type == Message.ReportType.OTHER_REPORT): + packet_bytes.extend( + [(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)] + ) + if self.report_type == Message.ReportType.OTHER_REPORT: return self.header(self.report_type) + packet_bytes else: return self.header(0x08 | self.report_type) + packet_bytes @@ -112,7 +114,7 @@ class GetReportMessage(Message): @dataclass class SetReportMessage(Message): report_type: int - data : bytes + data: bytes message_type = Message.MessageType.SET_REPORT def __bytes__(self) -> bytes: @@ -136,9 +138,33 @@ class SetProtocolMessage(Message): return self.header(self.protocol_mode) +@dataclass +class Suspend(Message): + message_type = Message.MessageType.CONTROL + + def __bytes__(self) -> bytes: + return self.header(Message.ControlCommand.SUSPEND) + + +@dataclass +class ExitSuspend(Message): + message_type = Message.MessageType.CONTROL + + def __bytes__(self) -> bytes: + return self.header(Message.ControlCommand.EXIT_SUSPEND) + + +@dataclass +class VirtualCableUnplug(Message): + message_type = Message.MessageType.CONTROL + + def __bytes__(self) -> bytes: + return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG) + + @dataclass class SendData(Message): - data : bytes + data: bytes message_type = Message.MessageType.DATA def __bytes__(self) -> bytes: @@ -147,13 +173,15 @@ class SendData(Message): # ----------------------------------------------------------------------------- class Host(EventEmitter): - l2cap_channel: Optional[l2cap.Channel] + l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] + l2cap_intr_channel: Optional[l2cap.ClassicChannel] def __init__(self, device: Device, connection: Connection) -> None: super().__init__() self.device = device self.connection = connection - self.l2cap_ctrl_channel= None + + self.l2cap_ctrl_channel = None self.l2cap_intr_channel = None # Register ourselves with the L2CAP channel manager @@ -166,8 +194,8 @@ class Host(EventEmitter): self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( self.connection, HID_CONTROL_PSM ) - except ProtocolError as error: - logging.exception(f'L2CAP connection failed: {error}') + except ProtocolError: + logging.exception(f'L2CAP connection failed.') raise assert self.l2cap_ctrl_channel is not None @@ -180,8 +208,8 @@ class Host(EventEmitter): self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( self.connection, HID_INTERRUPT_PSM ) - except ProtocolError as error: - logging.exception(f'L2CAP connection failed: {error}') + except ProtocolError: + logging.exception(f'L2CAP connection failed.') raise assert self.l2cap_intr_channel is not None @@ -193,48 +221,48 @@ class Host(EventEmitter): raise InvalidStateError('invalid state') channel = self.l2cap_intr_channel self.l2cap_intr_channel = None - await channel.disconnect() # type: ignore + await channel.disconnect() async def disconnect_control_channel(self) -> None: if self.l2cap_ctrl_channel is None: raise InvalidStateError('invalid state') channel = self.l2cap_ctrl_channel self.l2cap_ctrl_channel = None - await channel.disconnect() # type: ignore + await channel.disconnect() - def on_connection(self, l2cap_channel: l2cap.Channel) -> None: + def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) - def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: + def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: if l2cap_channel.psm == HID_CONTROL_PSM: - self.l2cap_ctrl_channel = l2cap_channel # type: ignore - self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu # type: ignore + self.l2cap_ctrl_channel = l2cap_channel + self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu else: - self.l2cap_intr_channel = l2cap_channel # type: ignore - self.l2cap_intr_channel.sink = self.on_intr_pdu # type: ignore + self.l2cap_intr_channel = l2cap_channel + self.l2cap_intr_channel.sink = self.on_intr_pdu logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') def on_ctrl_pdu(self, pdu: bytes) -> None: logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') # Here we will receive all kinds of packets, parse and then call respective callbacks message_type = pdu[0] >> 4 - param = pdu[0] & 0x0f + param = pdu[0] & 0x0F - if message_type == Message.MessageType.HANDSHAKE : + if message_type == Message.MessageType.HANDSHAKE: logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}') self.emit('handshake', Message.Handshake(param)) - elif message_type == Message.MessageType.DATA : + elif message_type == Message.MessageType.DATA: logger.debug('<<< HID CONTROL DATA') self.emit('data', pdu) - elif message_type == Message.MessageType.CONTROL : - if param == Message.ControlCommand.SUSPEND : + elif message_type == Message.MessageType.CONTROL: + if param == Message.ControlCommand.SUSPEND: logger.debug('<<< HID SUSPEND') self.emit('suspend', pdu) - elif param == Message.ControlCommand.EXIT_SUSPEND : + elif param == Message.ControlCommand.EXIT_SUSPEND: logger.debug('<<< HID EXIT SUSPEND') self.emit('exit_suspend', pdu) - elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG : + elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: logger.debug('<<< HID VIRTUAL CABLE UNPLUG') self.emit('virtual_cable_unplug') else: @@ -248,28 +276,30 @@ class Host(EventEmitter): self.emit("data", pdu) def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: - msg = GetReportMessage(report_type = report_type , report_id = report_id , buffer_size = buffer_size) + msg = GetReportMessage( + report_type=report_type, report_id=report_id, buffer_size=buffer_size + ) hid_message = bytes(msg) logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) # type: ignore + self.send_pdu_on_ctrl(hid_message) def set_report(self, report_type: int, data: bytes): - msg = SetReportMessage(report_type= report_type,data = data) + msg = SetReportMessage(report_type=report_type, data=data) hid_message = bytes(msg) logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) # type: ignore + self.send_pdu_on_ctrl(hid_message) def get_protocol(self): msg = GetProtocolMessage() hid_message = bytes(msg) logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) # type: ignore + self.send_pdu_on_ctrl(hid_message) def set_protocol(self, protocol_mode: int): - msg = SetProtocolMessage(protocol_mode= protocol_mode) + msg = SetProtocolMessage(protocol_mode=protocol_mode) hid_message = bytes(msg) logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) # type: ignore + self.send_pdu_on_ctrl(hid_message) def send_pdu_on_ctrl(self, msg: bytes) -> None: self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore @@ -281,22 +311,22 @@ class Host(EventEmitter): msg = SendData(data) hid_message = bytes(msg) logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}') - self.send_pdu_on_intr(hid_message) # type: ignore + self.send_pdu_on_intr(hid_message) def suspend(self): - header = (Message.MessageType.CONTROL << 4 | Message.ControlCommand.SUSPEND) - msg = bytearray([header]) - logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{msg.hex()}') - self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore + msg = Suspend() + hid_message = bytes(msg) + logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}') + self.send_pdu_on_ctrl(msg) def exit_suspend(self): - header = (Message.MessageType.CONTROL << 4 | Message.ControlCommand.EXIT_SUSPEND) - msg = bytearray([header]) - logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{msg.hex()}') - self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore + msg = ExitSuspend() + hid_message = bytes(msg) + logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}') + self.send_pdu_on_ctrl(msg) def virtual_cable_unplug(self): - header = (Message.MessageType.CONTROL << 4 | Message.ControlCommand.VIRTUAL_CABLE_UNPLUG) - msg = bytearray([header]) - logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {msg.hex()}') - self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore + msg = VirtualCableUnplug() + hid_message = bytes(msg) + logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}') + self.send_pdu_on_ctrl(msg) diff --git a/bumble/sdp.py b/bumble/sdp.py index bcd12cf..bc8303c 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -229,7 +229,7 @@ class DataElement: return DataElement(DataElement.UUID, value) @staticmethod - def text_string(value: str) -> DataElement: + def text_string(value: bytes) -> DataElement: return DataElement(DataElement.TEXT_STRING, value) @staticmethod diff --git a/examples/hid_report_parser.py b/examples/hid_report_parser.py index 1c891c6..e5f407f 100644 --- a/examples/hid_report_parser.py +++ b/examples/hid_report_parser.py @@ -139,7 +139,9 @@ class Mouse: # ------------------------------------------------------------------------------ class ReportParser: - def parse_input_report(self, input_report: bytes) -> None: # type: ignore + + @staticmethod + def parse_input_report(input_report: bytes) -> None: report_id = input_report[0] # pylint: disable=unsubscriptable-object report_length = len(input_report)