Correct HID type annotations

This commit is contained in:
Josh Wu
2024-08-13 22:52:00 +08:00
parent 03c79aacb2
commit 2248f9ae5e
2 changed files with 49 additions and 57 deletions

View File

@@ -23,13 +23,12 @@ import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING from typing import Optional, Callable
from typing_extensions import override from typing_extensions import override
from bumble import l2cap, device from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from .hci import Address from bumble.hci import Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
try: try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM self.connection, HID_CONTROL_PSM
) )
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError: except ProtocolError:
logging.exception(f'L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None: async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel # Create a new L2CAP connection - interrupt channel
try: try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM self.connection, HID_INTERRUPT_PSM
) )
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError: except ProtocolError:
logging.exception(f'L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None: async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None: if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
@@ -334,17 +329,18 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04 ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF SUCCESS = 0xFF
@dataclass
class GetSetStatus: class GetSetStatus:
def __init__(self) -> None: data: bytes = b''
self.data = bytearray() status: int = 0
self.status = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None: def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE) super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override @override
def on_ctrl_pdu(self, pdu: bytes) -> None: def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -410,7 +406,6 @@ class Device(HID):
buffer_size = 0 buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size) ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE: if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS: elif ret.status == self.GetSetReturn.SUCCESS:
@@ -428,7 +423,9 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None: def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
self.get_report_cb = cb self.get_report_cb = cb
logger.debug("GetReport callback registered successfully") logger.debug("GetReport callback registered successfully")
@@ -442,7 +439,6 @@ class Device(HID):
report_data = pdu[2:] report_data = pdu[2:]
report_size = len(report_data) + 1 report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data) ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -453,7 +449,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb( def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None] self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
) -> None: ) -> None:
self.set_report_cb = cb self.set_report_cb = cb
logger.debug("SetReport callback registered successfully") logger.debug("SetReport callback registered successfully")
@@ -464,13 +460,12 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.get_protocol_cb() ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None: def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
self.get_protocol_cb = cb self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully") logger.debug("GetProtocol callback registered successfully")
@@ -480,13 +475,14 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.set_protocol_cb(pdu[0] & 0x01) ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None: def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
self.set_protocol_cb = cb self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully") logger.debug("SetProtocol callback registered successfully")

View File

@@ -21,7 +21,7 @@ import os
import logging import logging
import json import json
import websockets import websockets
from bumble.colors import color import struct
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -30,9 +30,7 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE, BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID, BT_HIDP_PROTOCOL_ID,
UUID,
) )
from bumble.hci import Address
from bumble.hid import ( from bumble.hid import (
Device as HID_Device, Device as HID_Device,
HID_CONTROL_PSM, HID_CONTROL_PSM,
@@ -40,20 +38,17 @@ from bumble.hid import (
Message, Message,
) )
from bumble.sdp import ( from bumble.sdp import (
Client as SDP_Client,
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT, SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
) )
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices # SDP attributes for Bluetooth HID devices
@@ -430,7 +425,7 @@ deviceData = DeviceData()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def keyboard_device(hid_device): async def keyboard_device(hid_device: HID_Device):
# Start a Websocket server to receive events from a web page # Start a Websocket server to receive events from a web page
async def serve(websocket, _path): async def serve(websocket, _path):
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
# limiting x and y values within logical max and min range # limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x)) x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y)) y = max(log_min, min(log_max, y))
x_cord = x.to_bytes(signed=True) deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
y_cord = y.to_bytes(signed=True) ">bb", x, y
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord )
hid_device.send_data(deviceData.mouseData) hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK: except websockets.exceptions.ConnectionClosedOK:
pass pass
@@ -515,7 +510,9 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes): def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}') print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int): def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus() retValue = hid_device.GetSetStatus()
print( print(
"GET_REPORT report_id: " "GET_REPORT report_id: "
@@ -555,8 +552,7 @@ async def main() -> None:
def on_set_report_cb( def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes report_id: int, report_type: int, report_size: int, data: bytes
): ) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
print( print(
"SET_REPORT report_id: " "SET_REPORT report_id: "
+ str(report_id) + str(report_id)
@@ -568,33 +564,33 @@ async def main() -> None:
+ str(data) + str(data)
) )
if report_type == Message.ReportType.FEATURE_REPORT: if report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT: elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData): if report_id == 1 and report_size != len(deviceData.keyboardData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData): elif report_id == 2 and report_size != len(deviceData.mouseData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3: elif report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
else: else:
retValue.status = hid_device.GetSetReturn.SUCCESS status = HID_Device.GetSetReturn.SUCCESS
else: else:
retValue.status = hid_device.GetSetReturn.SUCCESS status = HID_Device.GetSetReturn.SUCCESS
return retValue return HID_Device.GetSetStatus(status=status)
def on_get_protocol_cb(): def on_get_protocol_cb() -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus() return HID_Device.GetSetStatus(
retValue.data = protocol_mode.to_bytes() data=bytes([protocol_mode]),
retValue.status = hid_device.GetSetReturn.SUCCESS status=hid_device.GetSetReturn.SUCCESS,
return retValue )
def on_set_protocol_cb(protocol: int): def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL. # We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}") print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST return HID_Device.GetSetStatus(
return retValue status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
def on_virtual_cable_unplug_cb(): def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug') print('Received Virtual Cable Unplug')