diff --git a/bumble/hid.py b/bumble/hid.py index 772a6bab..8c18b2aa 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -16,17 +16,20 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations +from dataclasses import dataclass import logging import asyncio +import enum from pyee import EventEmitter -from typing import Optional, Tuple, Callable, Dict, Union -from .device import Device, Connection +from typing import Optional, Tuple, Callable, Dict, Union, TYPE_CHECKING from . import core, l2cap # type: ignore from .colors import color # type: ignore from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError # type: ignore +if TYPE_CHECKING: + from bumble.device import Device, Connection # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -38,96 +41,128 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # fmt: off -HID_CONTROL_PSM = 0x0011 -HID_INTERRUPT_PSM = 0x0013 - -# HIDP message types -HID_HANDSHAKE = 0x00 -HID_CONTROL = 0x01 -HID_GET_REPORT = 0x04 -HID_SET_REPORT = 0x05 -HID_GET_PROTOCOL = 0x06 -HID_SET_PROTOCOL = 0x07 -HID_DATA = 0x0A - -# Report types -HID_OTHER_REPORT = 0x00 -HID_INPUT_REPORT = 0x01 -HID_OUTPUT_REPORT = 0x02 -HID_FEATURE_REPORT = 0x03 - -# Handshake parameters -HANDSHAKE_SUCCESSFUL = 0x00 -HANDSHAKE_NOT_READY = 0x01 -HANDSHAKE_ERR_INVALID_REPORT_ID = 0x02 -HANDSHAKE_ERR_UNSUPPORTED_REQUEST = 0x03 -HANDSHAKE_ERR_UNKNOWN = 0x0E -HANDSHAKE_ERR_FATAL = 0x0F - -# Protocol modes -HID_BOOT_PROTOCOL_MODE = 0x00 -HID_REPORT_PROTOCOL_MODE = 0x01 - -# Control Operations -HID_SUSPEND = 0x03 -HID_EXIT_SUSPEND = 0x04 -HID_VIRTUAL_CABLE_UNPLUG = 0x05 -class HIDPacket(): + +class Message(): + class HIDPsm(enum.IntEnum): + HID_CONTROL_PSM = 0x0011 + HID_INTERRUPT_PSM = 0x0013 + + # Report types + class ReportType(enum.IntEnum): + HID_OTHER_REPORT = 0x00 + HID_INPUT_REPORT = 0x01 + HID_OUTPUT_REPORT = 0x02 + HID_FEATURE_REPORT = 0x03 + + # Handshake parameters + class HandshakeState(enum.IntEnum): + HANDSHAKE_SUCCESSFUL = 0x00 + HANDSHAKE_NOT_READY = 0x01 + HANDSHAKE_ERR_INVALID_REPORT_ID = 0x02 + HANDSHAKE_ERR_UNSUPPORTED_REQUEST = 0x03 + HANDSHAKE_ERR_UNKNOWN = 0x0E + HANDSHAKE_ERR_FATAL = 0x0F + + class Type(enum.IntEnum): + HID_HANDSHAKE = 0x00 + HID_CONTROL = 0x01 + HID_GET_REPORT = 0x04 + HID_SET_REPORT = 0x05 + HID_GET_PROTOCOL = 0x06 + HID_SET_PROTOCOL = 0x07 + HID_DATA = 0x0A + + # Protocol modes + class ProtocolMode(enum.IntEnum): + HID_BOOT_PROTOCOL_MODE = 0x00 + HID_REPORT_PROTOCOL_MODE = 0x01 + + # Control Operations + class ControlCommand(enum.IntEnum): + HID_SUSPEND = 0x03 + HID_EXIT_SUSPEND = 0x04 + HID_VIRTUAL_CABLE_UNPLUG = 0x05 + + + # HIDP message types +@dataclass +class GetReportMessage(Message): + report_type : int + report_id : int + buffer_size : int + ''' def __init__(self, report_type: Optional[int] = None, report_id: Optional[int] = None, buffer_size: Optional[int] = None, - protocol_mode: Optional[int] = None, - data: Optional[bytes] = None) -> None: - + ): self.report_type = report_type self.report_id = report_id self.buffer_size = buffer_size - self.protocol_mode = protocol_mode - self.data = data - - def to_bytes_gr(self) -> bytes: - if(self.report_type == HID_OTHER_REPORT): + ''' + def __bytes__(self) -> bytes: + if(self.report_type == Message.ReportType.HID_OTHER_REPORT): param = self.report_type else: param = 0x08 | self.report_type - header = ((HID_GET_REPORT << 4) | param) + header = ((Message.Type.HID_GET_REPORT << 4) | param) packet_bytes = bytearray() packet_bytes.append(header) packet_bytes.append(self.report_id) packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)]) return bytes(packet_bytes) - def to_bytes_sr(self) -> bytes: - header = ((HID_SET_REPORT << 4) | self.report_type) +class SetReportMessage(Message): + + def __init__(self, + report_type: int, + data : bytes): + self.report_type = report_type + self.data = data + + def __bytes__(self) -> bytes: + header = ((Message.Type.HID_SET_REPORT << 4) | self.report_type) packet_bytes = bytearray() packet_bytes.append(header) packet_bytes.extend(self.data) return bytes(packet_bytes) - def to_bytes_gp(self) -> bytes: - header = (HID_GET_PROTOCOL << 4) +class GetProtocolMessage(Message): + + + def __bytes__(self) -> bytes: + header = (Message.Type.HID_GET_PROTOCOL << 4) packet_bytes = bytearray() packet_bytes.append(header) return bytes(packet_bytes) - def to_bytes_sp(self) -> bytes: - header = (HID_SET_PROTOCOL << 4 | self.protocol_mode) +class SetProtocolMessage(Message): + + def __init__(self, protocol_mode: int): + self.protocol_mode = protocol_mode + + + def __bytes__(self) -> bytes: + header = (Message.Type.HID_SET_PROTOCOL << 4 | self.protocol_mode) packet_bytes = bytearray() packet_bytes.append(header) packet_bytes.append(self.protocol_mode) return bytes(packet_bytes) - def to_bytes_send_data(self) -> bytes: - header = ((HID_DATA << 4) | HID_OUTPUT_REPORT) +class SendData(Message): + def __init__(self, data : bytes): + self.data = data + + def __bytes__(self) -> bytes: + header = ((Message.Type.HID_DATA << 4) | Message.ReportType.HID_OUTPUT_REPORT) packet_bytes = bytearray() packet_bytes.append(header) packet_bytes.extend(self.data) return bytes(packet_bytes) # ----------------------------------------------------------------------------- -class HIDHost(EventEmitter): +class Host(EventEmitter): l2cap_channel: Optional[l2cap.Channel] def __init__(self, device: Device, connection: Connection) -> None: @@ -138,17 +173,17 @@ class HIDHost(EventEmitter): self.l2cap_intr_channel = None # Register ourselves with the L2CAP channel manager - device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection) - device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection) + device.register_l2cap_server(Message.HIDPsm.HID_CONTROL_PSM, self.on_connection) + device.register_l2cap_server(Message.HIDPsm.HID_INTERRUPT_PSM, self.on_connection) async def connect_control_channel(self) -> None: # Create a new L2CAP connection - control channel try: self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( - self.connection, HID_CONTROL_PSM + self.connection, Message.HIDPsm.HID_CONTROL_PSM ) except ProtocolError as error: - logger.error(f'L2CAP connection failed: {error}') + logging.exception(f'L2CAP connection failed: {error}') raise assert self.l2cap_ctrl_channel is not None @@ -159,10 +194,10 @@ class HIDHost(EventEmitter): # Create a new L2CAP connection - interrupt channel try: self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( - self.connection, HID_INTERRUPT_PSM + self.connection, Message.HIDPsm.HID_INTERRUPT_PSM ) except ProtocolError as error: - logger.error(f'L2CAP connection failed: {error}') + logging.exception(f'L2CAP connection failed: {error}') raise assert self.l2cap_intr_channel is not None @@ -173,18 +208,24 @@ class HIDHost(EventEmitter): if self.l2cap_intr_channel is None: raise InvalidStateError('invalid state') await self.l2cap_intr_channel.disconnect() # type: ignore + channel = self.l2cap_intr_channel + self.l2cap_intr_channel = None + await channel.disconnect() # type: ignore async def disconnect_control_channel(self) -> None: if self.l2cap_ctrl_channel is None: raise InvalidStateError('invalid state') await self.l2cap_ctrl_channel.disconnect() # type: ignore + channel = self.l2cap_ctrl_channel + self.l2cap_ctrl_channel = None + await channel.disconnect() # type: ignore def on_connection(self, l2cap_channel: l2cap.Channel) -> None: logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: - if l2cap_channel.psm == HID_CONTROL_PSM: + if l2cap_channel.psm == Message.HIDPsm.HID_CONTROL_PSM: self.l2cap_ctrl_channel = l2cap_channel self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu else: @@ -197,15 +238,22 @@ class HIDHost(EventEmitter): # Here we will receive all kinds of packets, parse and then call respective callbacks message_type = pdu[0] >> 4 param = pdu[0] & 0x0f - if message_type == HID_HANDSHAKE : + + for command in Message.ControlCommand.__members__items(): + if param == command: + logger.debug(f'<<< ', command + pdu) + self.handle_handshake(param) + self.emit(command, pdu) + ''' + if message_type == Message.Type.HID_HANDSHAKE : logger.debug('<<< HID HANDSHAKE') self.handle_handshake(param) self.emit('handshake', pdu) - elif message_type == HID_DATA : + elif message_type == Message.Type.HID_DATA : logger.debug('<<< HID CONTROL DATA') self.emit('data', pdu) - elif message_type == HID_CONTROL : - if param == HID_SUSPEND : + elif message_type == Message.Type.HID_CONTROL : + if param == Message.ControlCommand.HID_SUSPEND : logger.debug('<<< HID SUSPEND') self.emit('suspend', pdu) elif param == HID_EXIT_SUSPEND : @@ -219,34 +267,35 @@ class HIDHost(EventEmitter): else: logger.debug('<<< HID CONTROL DATA') self.emit('data', pdu) + ''' def on_intr_pdu(self, pdu: bytes) -> None: logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') self.emit("data", pdu) def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: - msg = HIDPacket(report_type = report_type , report_id = report_id , buffer_size = buffer_size) - hid_packet = msg.to_bytes_gr() - logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_packet.hex()}') - self.send_pdu_on_ctrl(hid_packet) # type: ignore + msg = GetReportMessage(report_type = report_type , report_id = report_id , buffer_size = buffer_size) + hid_message = msg.__bytes__() + logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}') + self.send_pdu_on_ctrl(hid_message) # type: ignore def set_report(self, report_type: int, data: bytes): - msg = HIDPacket(report_type= report_type,data = data) - hid_packet = msg.to_bytes_sr() - logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_packet.hex()}') - self.send_pdu_on_ctrl(hid_packet) # type: ignore + msg = SetReportMessage(report_type= report_type,data = data) + hid_message = msg.__bytes__() + logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}') + self.send_pdu_on_ctrl(hid_message) # type: ignore def get_protocol(self): - msg = HIDPacket() - hid_packet = msg.to_bytes_gp() - logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_packet.hex()}') - self.send_pdu_on_ctrl(hid_packet) # type: ignore + msg = GetProtocolMessage() + hid_message = msg.__bytes__() + logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}') + self.send_pdu_on_ctrl(hid_message) # type: ignore def set_protocol(self, protocol_mode: int): - msg = HIDPacket(protocol_mode= protocol_mode) - hid_packet = msg.to_bytes_sp() - logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_packet.hex()}') - self.send_pdu_on_ctrl(hid_packet) # type: ignore + msg = SetProtocolMessage(protocol_mode= protocol_mode) + hid_message = msg.__bytes__() + logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}') + self.send_pdu_on_ctrl(hid_message) # type: ignore def send_pdu_on_ctrl(self, msg: bytes) -> None: self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore @@ -255,30 +304,34 @@ class HIDHost(EventEmitter): self.l2cap_intr_channel.send_pdu(msg) # type: ignore def send_data(self, data): - msg = HIDPacket(data= data) - hid_packet = msg.to_bytes_send_data() - logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_packet.hex()}') - self.send_pdu_on_intr(hid_packet) # type: ignore + msg = Message(data= data) + hid_message = msg.__bytes__() + logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}') + self.send_pdu_on_intr(hid_message) # type: ignore def suspend(self): - header = (HID_CONTROL << 4 | HID_SUSPEND) + header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_SUSPEND) msg = bytearray([header]) logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{msg.hex()}') self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore def exit_suspend(self): - header = (HID_CONTROL << 4 | HID_EXIT_SUSPEND) + header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_EXIT_SUSPEND) msg = bytearray([header]) logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{msg.hex()}') self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore def virtual_cable_unplug(self): - header = (HID_CONTROL << 4 | HID_VIRTUAL_CABLE_UNPLUG) + header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_VIRTUAL_CABLE_UNPLUG) msg = bytearray([header]) logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {msg.hex()}') self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore - def handle_handshake(self, param: int): + def handle_handshake(self, param: Message.HandshakeState): + for state in Message.HandshakeState.__members__items(): + if param == state: + logger.debug(f'<<< HID HANDSHAKE: ', state) + ''' if param == HANDSHAKE_SUCCESSFUL : logger.debug(f'<<< HID HANDSHAKE: SUCCESSFUL') elif param == HANDSHAKE_NOT_READY : @@ -293,3 +346,4 @@ class HIDHost(EventEmitter): logger.warning(f'<<< HID HANDSHAKE: ERR_FATAL') else: # 0x5-0xD = Reserved logger.warning("<<< HID HANDSHAKE: RESERVED VALUE") + ''' \ No newline at end of file diff --git a/examples/run_hid_host.py b/examples/run_hid_host.py index ad92b270..efb9768b 100644 --- a/examples/run_hid_host.py +++ b/examples/run_hid_host.py @@ -33,7 +33,7 @@ from bumble.core import ( BT_BR_EDR_TRANSPORT, ) from bumble.hci import Address -from bumble.hid import HIDHost, HID_INPUT_REPORT, HID_OTHER_REPORT, HID_BOOT_PROTOCOL_MODE, HID_REPORT_PROTOCOL_MODE +from bumble.hid import Host, Message from bumble.sdp import ( Client as SDP_Client, DataElement, @@ -243,13 +243,13 @@ async def main(): report_length = len(pdu[1:]) report_id = pdu[1] - if (report_type != HID_OTHER_REPORT): + if (report_type != Message.ReportType.HID_OTHER_REPORT): print(color(f' Report type = {report_type}, Report length = {report_length}, Report id = {report_id}', 'blue', None, 'bold')) if ((report_length <= 1) or (report_id == 0)): return - if report_type == HID_INPUT_REPORT: + if report_type == Message.ReportType.HID_INPUT_REPORT: ReportParser.parse_input_report(pdu[1:]) #type: ignore async def handle_virtual_cable_unplug(): @@ -290,7 +290,7 @@ async def main(): # Create HID host and start it print('@@@ Starting HID Host...') - hid_host = HIDHost(device, connection) + hid_host = Host(device, connection) # Register for HID data call back hid_host.on('data', on_hid_data_cb) @@ -383,10 +383,10 @@ async def main(): choice1 = choice1.decode('utf-8').strip() if choice1 == '0': - hid_host.set_protocol(HID_BOOT_PROTOCOL_MODE) + hid_host.set_protocol(Message.ProtocolMode.HID_BOOT_PROTOCOL_MODE) elif choice1 == '1': - hid_host.set_protocol(HID_REPORT_PROTOCOL_MODE) + hid_host.set_protocol(Message.ProtocolMode.HID_REPORT_PROTOCOL_MODE) else: print('Incorrect option selected')