Correct HID type annotations

This commit is contained in:
Josh Wu
2024-08-13 22:52:00 +08:00
parent 03c79aacb2
commit 2248f9ae5e
2 changed files with 49 additions and 57 deletions

View File

@@ -23,13 +23,12 @@ import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing import Optional, Callable
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
from bumble.hci import Address
# -----------------------------------------------------------------------------
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
@@ -334,17 +329,18 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
@dataclass
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -410,7 +406,6 @@ class Device(HID):
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
@@ -428,7 +423,9 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
@@ -442,7 +439,6 @@ class Device(HID):
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -453,7 +449,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
@@ -464,13 +460,12 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
@@ -480,13 +475,14 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")

View File

@@ -21,7 +21,7 @@ import os
import logging
import json
import websockets
from bumble.colors import color
import struct
from bumble.device import Device
from bumble.transport import open_transport_or_link
@@ -30,9 +30,7 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID,
UUID,
)
from bumble.hci import Address
from bumble.hid import (
Device as HID_Device,
HID_CONTROL_PSM,
@@ -40,20 +38,17 @@ from bumble.hid import (
Message,
)
from bumble.sdp import (
Client as SDP_Client,
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices
@@ -430,7 +425,7 @@ deviceData = DeviceData()
# -----------------------------------------------------------------------------
async def keyboard_device(hid_device):
async def keyboard_device(hid_device: HID_Device):
# Start a Websocket server to receive events from a web page
async def serve(websocket, _path):
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
# limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y))
x_cord = x.to_bytes(signed=True)
y_cord = y.to_bytes(signed=True)
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
">bb", x, y
)
hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -515,7 +510,9 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
print(
"GET_REPORT report_id: "
@@ -555,8 +552,7 @@ async def main() -> None:
def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes
):
retValue = hid_device.GetSetStatus()
) -> HID_Device.GetSetStatus:
print(
"SET_REPORT report_id: "
+ str(report_id)
@@ -568,33 +564,33 @@ async def main() -> None:
+ str(data)
)
if report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
status = HID_Device.GetSetReturn.SUCCESS
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
status = HID_Device.GetSetReturn.SUCCESS
return retValue
return HID_Device.GetSetStatus(status=status)
def on_get_protocol_cb():
retValue = hid_device.GetSetStatus()
retValue.data = protocol_mode.to_bytes()
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
return HID_Device.GetSetStatus(
data=bytes([protocol_mode]),
status=hid_device.GetSetReturn.SUCCESS,
)
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
return HID_Device.GetSetStatus(
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')