Commit to fix review comments for dataclass and subclass, shifting contants to Message Class

Commit for enum and dataclass
This commit is contained in:
skarnataki
2023-09-29 13:14:49 +00:00
committed by Lucas Abel
parent 5ce353bcde
commit 5ddee17411
2 changed files with 150 additions and 96 deletions

View File

@@ -16,17 +16,20 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import logging import logging
import asyncio import asyncio
import enum
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, Tuple, Callable, Dict, Union from typing import Optional, Tuple, Callable, Dict, Union, TYPE_CHECKING
from .device import Device, Connection
from . import core, l2cap # type: ignore from . import core, l2cap # type: ignore
from .colors import color # type: ignore from .colors import color # type: ignore
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError # type: ignore from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError # type: ignore
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -38,96 +41,128 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off # fmt: off
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
# HIDP message types
HID_HANDSHAKE = 0x00
HID_CONTROL = 0x01
HID_GET_REPORT = 0x04
HID_SET_REPORT = 0x05
HID_GET_PROTOCOL = 0x06
HID_SET_PROTOCOL = 0x07
HID_DATA = 0x0A
# Report types
HID_OTHER_REPORT = 0x00
HID_INPUT_REPORT = 0x01
HID_OUTPUT_REPORT = 0x02
HID_FEATURE_REPORT = 0x03
# Handshake parameters
HANDSHAKE_SUCCESSFUL = 0x00
HANDSHAKE_NOT_READY = 0x01
HANDSHAKE_ERR_INVALID_REPORT_ID = 0x02
HANDSHAKE_ERR_UNSUPPORTED_REQUEST = 0x03
HANDSHAKE_ERR_UNKNOWN = 0x0E
HANDSHAKE_ERR_FATAL = 0x0F
# Protocol modes
HID_BOOT_PROTOCOL_MODE = 0x00
HID_REPORT_PROTOCOL_MODE = 0x01
# Control Operations
HID_SUSPEND = 0x03
HID_EXIT_SUSPEND = 0x04
HID_VIRTUAL_CABLE_UNPLUG = 0x05
class HIDPacket():
class Message():
class HIDPsm(enum.IntEnum):
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
# Report types
class ReportType(enum.IntEnum):
HID_OTHER_REPORT = 0x00
HID_INPUT_REPORT = 0x01
HID_OUTPUT_REPORT = 0x02
HID_FEATURE_REPORT = 0x03
# Handshake parameters
class HandshakeState(enum.IntEnum):
HANDSHAKE_SUCCESSFUL = 0x00
HANDSHAKE_NOT_READY = 0x01
HANDSHAKE_ERR_INVALID_REPORT_ID = 0x02
HANDSHAKE_ERR_UNSUPPORTED_REQUEST = 0x03
HANDSHAKE_ERR_UNKNOWN = 0x0E
HANDSHAKE_ERR_FATAL = 0x0F
class Type(enum.IntEnum):
HID_HANDSHAKE = 0x00
HID_CONTROL = 0x01
HID_GET_REPORT = 0x04
HID_SET_REPORT = 0x05
HID_GET_PROTOCOL = 0x06
HID_SET_PROTOCOL = 0x07
HID_DATA = 0x0A
# Protocol modes
class ProtocolMode(enum.IntEnum):
HID_BOOT_PROTOCOL_MODE = 0x00
HID_REPORT_PROTOCOL_MODE = 0x01
# Control Operations
class ControlCommand(enum.IntEnum):
HID_SUSPEND = 0x03
HID_EXIT_SUSPEND = 0x04
HID_VIRTUAL_CABLE_UNPLUG = 0x05
# HIDP message types
@dataclass
class GetReportMessage(Message):
report_type : int
report_id : int
buffer_size : int
'''
def __init__(self, def __init__(self,
report_type: Optional[int] = None, report_type: Optional[int] = None,
report_id: Optional[int] = None, report_id: Optional[int] = None,
buffer_size: Optional[int] = None, buffer_size: Optional[int] = None,
protocol_mode: Optional[int] = None, ):
data: Optional[bytes] = None) -> None:
self.report_type = report_type self.report_type = report_type
self.report_id = report_id self.report_id = report_id
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.protocol_mode = protocol_mode '''
self.data = data def __bytes__(self) -> bytes:
if(self.report_type == Message.ReportType.HID_OTHER_REPORT):
def to_bytes_gr(self) -> bytes:
if(self.report_type == HID_OTHER_REPORT):
param = self.report_type param = self.report_type
else: else:
param = 0x08 | self.report_type param = 0x08 | self.report_type
header = ((HID_GET_REPORT << 4) | param) header = ((Message.Type.HID_GET_REPORT << 4) | param)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header) packet_bytes.append(header)
packet_bytes.append(self.report_id) packet_bytes.append(self.report_id)
packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)]) packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)])
return bytes(packet_bytes) return bytes(packet_bytes)
def to_bytes_sr(self) -> bytes: class SetReportMessage(Message):
header = ((HID_SET_REPORT << 4) | self.report_type)
def __init__(self,
report_type: int,
data : bytes):
self.report_type = report_type
self.data = data
def __bytes__(self) -> bytes:
header = ((Message.Type.HID_SET_REPORT << 4) | self.report_type)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header) packet_bytes.append(header)
packet_bytes.extend(self.data) packet_bytes.extend(self.data)
return bytes(packet_bytes) return bytes(packet_bytes)
def to_bytes_gp(self) -> bytes: class GetProtocolMessage(Message):
header = (HID_GET_PROTOCOL << 4)
def __bytes__(self) -> bytes:
header = (Message.Type.HID_GET_PROTOCOL << 4)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header) packet_bytes.append(header)
return bytes(packet_bytes) return bytes(packet_bytes)
def to_bytes_sp(self) -> bytes: class SetProtocolMessage(Message):
header = (HID_SET_PROTOCOL << 4 | self.protocol_mode)
def __init__(self, protocol_mode: int):
self.protocol_mode = protocol_mode
def __bytes__(self) -> bytes:
header = (Message.Type.HID_SET_PROTOCOL << 4 | self.protocol_mode)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header) packet_bytes.append(header)
packet_bytes.append(self.protocol_mode) packet_bytes.append(self.protocol_mode)
return bytes(packet_bytes) return bytes(packet_bytes)
def to_bytes_send_data(self) -> bytes: class SendData(Message):
header = ((HID_DATA << 4) | HID_OUTPUT_REPORT) def __init__(self, data : bytes):
self.data = data
def __bytes__(self) -> bytes:
header = ((Message.Type.HID_DATA << 4) | Message.ReportType.HID_OUTPUT_REPORT)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header) packet_bytes.append(header)
packet_bytes.extend(self.data) packet_bytes.extend(self.data)
return bytes(packet_bytes) return bytes(packet_bytes)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HIDHost(EventEmitter): class Host(EventEmitter):
l2cap_channel: Optional[l2cap.Channel] l2cap_channel: Optional[l2cap.Channel]
def __init__(self, device: Device, connection: Connection) -> None: def __init__(self, device: Device, connection: Connection) -> None:
@@ -138,17 +173,17 @@ class HIDHost(EventEmitter):
self.l2cap_intr_channel = None self.l2cap_intr_channel = None
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection) device.register_l2cap_server(Message.HIDPsm.HID_CONTROL_PSM, self.on_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection) device.register_l2cap_server(Message.HIDPsm.HID_INTERRUPT_PSM, self.on_connection)
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
try: try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM self.connection, Message.HIDPsm.HID_CONTROL_PSM
) )
except ProtocolError as error: except ProtocolError as error:
logger.error(f'L2CAP connection failed: {error}') logging.exception(f'L2CAP connection failed: {error}')
raise raise
assert self.l2cap_ctrl_channel is not None assert self.l2cap_ctrl_channel is not None
@@ -159,10 +194,10 @@ class HIDHost(EventEmitter):
# Create a new L2CAP connection - interrupt channel # Create a new L2CAP connection - interrupt channel
try: try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM self.connection, Message.HIDPsm.HID_INTERRUPT_PSM
) )
except ProtocolError as error: except ProtocolError as error:
logger.error(f'L2CAP connection failed: {error}') logging.exception(f'L2CAP connection failed: {error}')
raise raise
assert self.l2cap_intr_channel is not None assert self.l2cap_intr_channel is not None
@@ -173,18 +208,24 @@ class HIDHost(EventEmitter):
if self.l2cap_intr_channel is None: if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
await self.l2cap_intr_channel.disconnect() # type: ignore await self.l2cap_intr_channel.disconnect() # type: ignore
channel = self.l2cap_intr_channel
self.l2cap_intr_channel = None
await channel.disconnect() # type: ignore
async def disconnect_control_channel(self) -> None: async def disconnect_control_channel(self) -> None:
if self.l2cap_ctrl_channel is None: if self.l2cap_ctrl_channel is None:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
await self.l2cap_ctrl_channel.disconnect() # type: ignore await self.l2cap_ctrl_channel.disconnect() # type: ignore
channel = self.l2cap_ctrl_channel
self.l2cap_ctrl_channel = None
await channel.disconnect() # type: ignore
def on_connection(self, l2cap_channel: l2cap.Channel) -> None: def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM: if l2cap_channel.psm == Message.HIDPsm.HID_CONTROL_PSM:
self.l2cap_ctrl_channel = l2cap_channel self.l2cap_ctrl_channel = l2cap_channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
else: else:
@@ -197,15 +238,22 @@ class HIDHost(EventEmitter):
# Here we will receive all kinds of packets, parse and then call respective callbacks # Here we will receive all kinds of packets, parse and then call respective callbacks
message_type = pdu[0] >> 4 message_type = pdu[0] >> 4
param = pdu[0] & 0x0f param = pdu[0] & 0x0f
if message_type == HID_HANDSHAKE :
for command in Message.ControlCommand.__members__items():
if param == command:
logger.debug(f'<<< ', command + pdu)
self.handle_handshake(param)
self.emit(command, pdu)
'''
if message_type == Message.Type.HID_HANDSHAKE :
logger.debug('<<< HID HANDSHAKE') logger.debug('<<< HID HANDSHAKE')
self.handle_handshake(param) self.handle_handshake(param)
self.emit('handshake', pdu) self.emit('handshake', pdu)
elif message_type == HID_DATA : elif message_type == Message.Type.HID_DATA :
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu) self.emit('data', pdu)
elif message_type == HID_CONTROL : elif message_type == Message.Type.HID_CONTROL :
if param == HID_SUSPEND : if param == Message.ControlCommand.HID_SUSPEND :
logger.debug('<<< HID SUSPEND') logger.debug('<<< HID SUSPEND')
self.emit('suspend', pdu) self.emit('suspend', pdu)
elif param == HID_EXIT_SUSPEND : elif param == HID_EXIT_SUSPEND :
@@ -219,34 +267,35 @@ class HIDHost(EventEmitter):
else: else:
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu) self.emit('data', pdu)
'''
def on_intr_pdu(self, pdu: bytes) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("data", pdu) self.emit("data", pdu)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = HIDPacket(report_type = report_type , report_id = report_id , buffer_size = buffer_size) msg = GetReportMessage(report_type = report_type , report_id = report_id , buffer_size = buffer_size)
hid_packet = msg.to_bytes_gr() hid_message = msg.__bytes__()
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_packet.hex()}') logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_packet) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def set_report(self, report_type: int, data: bytes): def set_report(self, report_type: int, data: bytes):
msg = HIDPacket(report_type= report_type,data = data) msg = SetReportMessage(report_type= report_type,data = data)
hid_packet = msg.to_bytes_sr() hid_message = msg.__bytes__()
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_packet.hex()}') logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_packet) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def get_protocol(self): def get_protocol(self):
msg = HIDPacket() msg = GetProtocolMessage()
hid_packet = msg.to_bytes_gp() hid_message = msg.__bytes__()
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_packet.hex()}') logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_packet) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def set_protocol(self, protocol_mode: int): def set_protocol(self, protocol_mode: int):
msg = HIDPacket(protocol_mode= protocol_mode) msg = SetProtocolMessage(protocol_mode= protocol_mode)
hid_packet = msg.to_bytes_sp() hid_message = msg.__bytes__()
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_packet.hex()}') logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_packet) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def send_pdu_on_ctrl(self, msg: bytes) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore
@@ -255,30 +304,34 @@ class HIDHost(EventEmitter):
self.l2cap_intr_channel.send_pdu(msg) # type: ignore self.l2cap_intr_channel.send_pdu(msg) # type: ignore
def send_data(self, data): def send_data(self, data):
msg = HIDPacket(data= data) msg = Message(data= data)
hid_packet = msg.to_bytes_send_data() hid_message = msg.__bytes__()
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_packet.hex()}') logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_packet) # type: ignore self.send_pdu_on_intr(hid_message) # type: ignore
def suspend(self): def suspend(self):
header = (HID_CONTROL << 4 | HID_SUSPEND) header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_SUSPEND)
msg = bytearray([header]) msg = bytearray([header])
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{msg.hex()}') logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{msg.hex()}')
self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore
def exit_suspend(self): def exit_suspend(self):
header = (HID_CONTROL << 4 | HID_EXIT_SUSPEND) header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_EXIT_SUSPEND)
msg = bytearray([header]) msg = bytearray([header])
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{msg.hex()}') logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{msg.hex()}')
self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore
def virtual_cable_unplug(self): def virtual_cable_unplug(self):
header = (HID_CONTROL << 4 | HID_VIRTUAL_CABLE_UNPLUG) header = (Message.Type.HID_CONTROL << 4 | Message.ControlCommand.HID_VIRTUAL_CABLE_UNPLUG)
msg = bytearray([header]) msg = bytearray([header])
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {msg.hex()}') logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {msg.hex()}')
self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore
def handle_handshake(self, param: int): def handle_handshake(self, param: Message.HandshakeState):
for state in Message.HandshakeState.__members__items():
if param == state:
logger.debug(f'<<< HID HANDSHAKE: ', state)
'''
if param == HANDSHAKE_SUCCESSFUL : if param == HANDSHAKE_SUCCESSFUL :
logger.debug(f'<<< HID HANDSHAKE: SUCCESSFUL') logger.debug(f'<<< HID HANDSHAKE: SUCCESSFUL')
elif param == HANDSHAKE_NOT_READY : elif param == HANDSHAKE_NOT_READY :
@@ -293,3 +346,4 @@ class HIDHost(EventEmitter):
logger.warning(f'<<< HID HANDSHAKE: ERR_FATAL') logger.warning(f'<<< HID HANDSHAKE: ERR_FATAL')
else: # 0x5-0xD = Reserved else: # 0x5-0xD = Reserved
logger.warning("<<< HID HANDSHAKE: RESERVED VALUE") logger.warning("<<< HID HANDSHAKE: RESERVED VALUE")
'''

View File

@@ -33,7 +33,7 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
) )
from bumble.hci import Address from bumble.hci import Address
from bumble.hid import HIDHost, HID_INPUT_REPORT, HID_OTHER_REPORT, HID_BOOT_PROTOCOL_MODE, HID_REPORT_PROTOCOL_MODE from bumble.hid import Host, Message
from bumble.sdp import ( from bumble.sdp import (
Client as SDP_Client, Client as SDP_Client,
DataElement, DataElement,
@@ -243,13 +243,13 @@ async def main():
report_length = len(pdu[1:]) report_length = len(pdu[1:])
report_id = pdu[1] report_id = pdu[1]
if (report_type != HID_OTHER_REPORT): if (report_type != Message.ReportType.HID_OTHER_REPORT):
print(color(f' Report type = {report_type}, Report length = {report_length}, Report id = {report_id}', 'blue', None, 'bold')) print(color(f' Report type = {report_type}, Report length = {report_length}, Report id = {report_id}', 'blue', None, 'bold'))
if ((report_length <= 1) or (report_id == 0)): if ((report_length <= 1) or (report_id == 0)):
return return
if report_type == HID_INPUT_REPORT: if report_type == Message.ReportType.HID_INPUT_REPORT:
ReportParser.parse_input_report(pdu[1:]) #type: ignore ReportParser.parse_input_report(pdu[1:]) #type: ignore
async def handle_virtual_cable_unplug(): async def handle_virtual_cable_unplug():
@@ -290,7 +290,7 @@ async def main():
# Create HID host and start it # Create HID host and start it
print('@@@ Starting HID Host...') print('@@@ Starting HID Host...')
hid_host = HIDHost(device, connection) hid_host = Host(device, connection)
# Register for HID data call back # Register for HID data call back
hid_host.on('data', on_hid_data_cb) hid_host.on('data', on_hid_data_cb)
@@ -383,10 +383,10 @@ async def main():
choice1 = choice1.decode('utf-8').strip() choice1 = choice1.decode('utf-8').strip()
if choice1 == '0': if choice1 == '0':
hid_host.set_protocol(HID_BOOT_PROTOCOL_MODE) hid_host.set_protocol(Message.ProtocolMode.HID_BOOT_PROTOCOL_MODE)
elif choice1 == '1': elif choice1 == '1':
hid_host.set_protocol(HID_REPORT_PROTOCOL_MODE) hid_host.set_protocol(Message.ProtocolMode.HID_REPORT_PROTOCOL_MODE)
else: else:
print('Incorrect option selected') print('Incorrect option selected')