Review comment fix HID device

This commit is contained in:
skarnataki
2023-11-28 13:42:25 +00:00
parent ad0f035df5
commit 403a13e4c6

View File

@@ -19,17 +19,15 @@ from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct
from pyee import EventEmitter
from typing import Optional, TYPE_CHECKING
from bumble import l2cap
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
if TYPE_CHECKING:
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
@@ -105,10 +103,11 @@ class GetReportMessage(Message):
if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes
else:
packet_bytes.extend(
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
return self.header(0x08 | self.report_type) + packet_bytes
@dataclass
@@ -128,10 +127,7 @@ class SendControlData(Message):
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.extend(self.data)
return self.header(self.report_type) + packet_bytes
return self.header(self.report_type) + self.data
@dataclass
@@ -151,17 +147,6 @@ class SetProtocolMessage(Message):
return self.header(self.protocol_mode)
@dataclass
class GetProtocolReplyMessage(Message):
protocol_mode: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.protocol_mode)
return self.header(Message.ReportType.OTHER_REPORT) + packet_bytes
@dataclass
class Suspend(Message):
message_type = Message.MessageType.CONTROL
@@ -215,7 +200,7 @@ class HID(EventEmitter):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: Device, role: int) -> None:
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.device = device
self.connection = None
@@ -273,11 +258,11 @@ class HID(EventEmitter):
self.l2cap_ctrl_channel = None
await channel.disconnect()
def on_device_connection(self, connection: Connection) -> None:
def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection # type: ignore[assignment]
self.remote_device_bd_address = (
connection.peer_address
) # type: ignore[assignment]
connection.peer_address # type: ignore[assignment]
)
connection.on('disconnection', self.on_disconnection)
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
@@ -341,9 +326,9 @@ class HID(EventEmitter):
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(
self.send_handshake_message( # type: ignore[attr-defined]
Message.Handshake.ERR_UNSUPPORTED_REQUEST
) # type: ignore[attr-defined]
)
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
@@ -378,7 +363,7 @@ class HID(EventEmitter):
# -----------------------------------------------------------------------------
class Device(HID): # type: ignore[no-redef]
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
@@ -389,10 +374,10 @@ class Device(HID): # type: ignore[no-redef]
class GetSetStatus:
def __init__(self) -> None:
self.data: bytes
self.data = bytearray()
self.status = 0
def __init__(self, device: Device) -> None:
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
self.get_report_cb = None
self.set_report_cb = None
@@ -412,7 +397,10 @@ class Device(HID): # type: ignore[no-redef]
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
ret = self.GetSetStatus()
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
@@ -422,32 +410,23 @@ class Device(HID): # type: ignore[no-redef]
else:
buffer_size = 0
if self.get_report_cb != None:
ret = self.get_report_cb(
report_id, report_type, buffer_size
) # type: ignore
ret = self.get_report_cb(report_id, report_type, buffer_size)
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if (
len(data) < self.l2cap_ctrl_channel.mtu
): # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.mtu:
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
else:
logger.debug("GetReport callback not registered !!")
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb):
@@ -455,56 +434,55 @@ class Device(HID): # type: ignore[no-redef]
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb != None:
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(pdu[1:])
ret = self.set_report_cb(
report_id, report_type, report_size, report_data
) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
else:
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(pdu[1:])
ret = self.set_report_cb(
report_id, report_type, report_size, report_data
) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(self, cb):
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
ret = self.GetSetStatus()
if self.get_protocol_cb != None:
ret = self.get_protocol_cb() # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
return
else:
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb):
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
ret = self.GetSetStatus()
if self.set_protocol_cb != None:
ret = self.set_protocol_cb(pdu[0] & 0x01) # type: ignore
if ret.status == self.GetSetReturn.SUCCESS:
return
else:
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb):
self.set_protocol_cb = cb
@@ -513,7 +491,7 @@ class Device(HID): # type: ignore[no-redef]
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: Device) -> None:
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: