forked from auracaster/bumble_mirror
Correct HID type annotations
This commit is contained in:
@@ -23,13 +23,12 @@ import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
from typing import Optional, Callable
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap, device
|
||||
from bumble.colors import color
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
from .hci import Address
|
||||
from bumble.hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
try:
|
||||
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_CONTROL_PSM
|
||||
)
|
||||
channel.sink = self.on_ctrl_pdu
|
||||
self.l2cap_ctrl_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_ctrl_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
|
||||
async def connect_interrupt_channel(self) -> None:
|
||||
# Create a new L2CAP connection - interrupt channel
|
||||
try:
|
||||
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_INTERRUPT_PSM
|
||||
)
|
||||
channel.sink = self.on_intr_pdu
|
||||
self.l2cap_intr_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_intr_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
|
||||
async def disconnect_interrupt_channel(self) -> None:
|
||||
if self.l2cap_intr_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
@@ -334,17 +329,18 @@ class Device(HID):
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
@dataclass
|
||||
class GetSetStatus:
|
||||
def __init__(self) -> None:
|
||||
self.data = bytearray()
|
||||
self.status = 0
|
||||
data: bytes = b''
|
||||
status: int = 0
|
||||
|
||||
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
|
||||
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
@@ -410,7 +406,6 @@ class Device(HID):
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
@@ -428,7 +423,9 @@ class Device(HID):
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
||||
def register_get_report_cb(
|
||||
self, cb: Callable[[int, int, int], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
@@ -442,7 +439,6 @@ class Device(HID):
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
@@ -453,7 +449,7 @@ class Device(HID):
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], None]
|
||||
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
@@ -464,13 +460,12 @@ class Device(HID):
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
||||
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
@@ -480,13 +475,14 @@ class Device(HID):
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
||||
def register_set_protocol_cb(
|
||||
self, cb: Callable[[int], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user