Review comment fix HID device

This commit is contained in:
skarnataki
2023-11-28 13:42:25 +00:00
parent ad0f035df5
commit 403a13e4c6

View File

@@ -19,17 +19,15 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import enum import enum
import struct
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
from bumble import l2cap from bumble import l2cap, device
from bumble.colors import color from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -105,10 +103,11 @@ class GetReportMessage(Message):
if self.buffer_size == 0: if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes return self.header(self.report_type) + packet_bytes
else: else:
packet_bytes.extend( return (
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)] self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
) )
return self.header(0x08 | self.report_type) + packet_bytes
@dataclass @dataclass
@@ -128,10 +127,7 @@ class SendControlData(Message):
message_type = Message.MessageType.DATA message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
packet_bytes = bytearray() return self.header(self.report_type) + self.data
packet_bytes.extend(self.data)
return self.header(self.report_type) + packet_bytes
@dataclass @dataclass
@@ -151,17 +147,6 @@ class SetProtocolMessage(Message):
return self.header(self.protocol_mode) return self.header(self.protocol_mode)
@dataclass
class GetProtocolReplyMessage(Message):
protocol_mode: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.protocol_mode)
return self.header(Message.ReportType.OTHER_REPORT) + packet_bytes
@dataclass @dataclass
class Suspend(Message): class Suspend(Message):
message_type = Message.MessageType.CONTROL message_type = Message.MessageType.CONTROL
@@ -215,7 +200,7 @@ class HID(EventEmitter):
HOST = 0x00 HOST = 0x00
DEVICE = 0x01 DEVICE = 0x01
def __init__(self, device: Device, role: int) -> None: def __init__(self, device: device.Device, role: Role) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.connection = None self.connection = None
@@ -273,11 +258,11 @@ class HID(EventEmitter):
self.l2cap_ctrl_channel = None self.l2cap_ctrl_channel = None
await channel.disconnect() await channel.disconnect()
def on_device_connection(self, connection: Connection) -> None: def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection # type: ignore[assignment] self.connection = connection # type: ignore[assignment]
self.remote_device_bd_address = ( self.remote_device_bd_address = (
connection.peer_address connection.peer_address # type: ignore[assignment]
) # type: ignore[assignment] )
connection.on('disconnection', self.on_disconnection) connection.on('disconnection', self.on_disconnection)
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
@@ -341,9 +326,9 @@ class HID(EventEmitter):
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else: else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message( self.send_handshake_message( # type: ignore[attr-defined]
Message.Handshake.ERR_UNSUPPORTED_REQUEST Message.Handshake.ERR_UNSUPPORTED_REQUEST
) # type: ignore[attr-defined] )
def on_intr_pdu(self, pdu: bytes) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
@@ -378,7 +363,7 @@ class HID(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Device(HID): # type: ignore[no-redef] class Device(HID):
class GetSetReturn(enum.IntEnum): class GetSetReturn(enum.IntEnum):
FAILURE = 0x00 FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01 REPORT_ID_NOT_FOUND = 0x01
@@ -389,10 +374,10 @@ class Device(HID): # type: ignore[no-redef]
class GetSetStatus: class GetSetStatus:
def __init__(self) -> None: def __init__(self) -> None:
self.data: bytes self.data = bytearray()
self.status = 0 self.status = 0
def __init__(self, device: Device) -> None: def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE) super().__init__(device, HID.Role.DEVICE)
self.get_report_cb = None self.get_report_cb = None
self.set_report_cb = None self.set_report_cb = None
@@ -412,7 +397,10 @@ class Device(HID): # type: ignore[no-redef]
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes): def handle_get_report(self, pdu: bytes):
ret = self.GetSetStatus() if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03 report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3 buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1] report_id = pdu[1]
@@ -422,32 +410,23 @@ class Device(HID): # type: ignore[no-redef]
else: else:
buffer_size = 0 buffer_size = 0
if self.get_report_cb != None: ret = self.get_report_cb(report_id, report_type, buffer_size)
ret = self.get_report_cb(
report_id, report_type, buffer_size
) # type: ignore
if ret.status == self.GetSetReturn.FAILURE: if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS: elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray() data = bytearray()
data.append(report_id) data.append(report_id)
data.extend(ret.data) data.extend(ret.data)
if ( if len(data) < self.l2cap_ctrl_channel.mtu:
len(data) < self.l2cap_ctrl_channel.mtu self.send_control_data(report_type=report_type, data=data)
): # type: ignore[union-attr] else:
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else: elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
logger.debug("GetReport callback not registered !!") self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb): def register_get_report_cb(self, cb):
@@ -455,56 +434,55 @@ class Device(HID): # type: ignore[no-redef]
logger.debug("GetReport callback registered successfully") logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes): def handle_set_report(self, pdu: bytes):
if self.set_report_cb != None: if self.set_report_cb is None:
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(pdu[1:])
ret = self.set_report_cb(
report_id, report_type, report_size, report_data
) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
else:
logger.debug("SetReport callback not registered !!") logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(pdu[1:])
ret = self.set_report_cb(
report_id, report_type, report_size, report_data
) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(self, cb): def register_set_report_cb(self, cb):
self.set_report_cb = cb self.set_report_cb = cb
logger.debug("SetReport callback registered successfully") logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes): def handle_get_protocol(self, pdu: bytes):
ret = self.GetSetStatus() if self.get_protocol_cb is None:
if self.get_protocol_cb != None:
ret = self.get_protocol_cb() # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
return
else:
logger.debug("GetProtocol callback not registered !!") logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) return
ret = self.get_protocol_cb()
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb): def register_get_protocol_cb(self, cb):
self.get_protocol_cb = cb self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully") logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes): def handle_set_protocol(self, pdu: bytes):
ret = self.GetSetStatus() if self.set_protocol_cb is None:
if self.set_protocol_cb != None:
ret = self.set_protocol_cb(pdu[0] & 0x01) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
return
else:
logger.debug("SetProtocol callback not registered !!") logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) return
ret = self.set_protocol_cb(pdu[0] & 0x01)
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb): def register_set_protocol_cb(self, cb):
self.set_protocol_cb = cb self.set_protocol_cb = cb
@@ -513,7 +491,7 @@ class Device(HID): # type: ignore[no-redef]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(HID): class Host(HID):
def __init__(self, device: Device) -> None: def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST) super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: