From 249a205d8e7da84d0f28e5b31491c0d26cff7bc4 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 30 Aug 2023 00:48:01 +0800 Subject: [PATCH] Typing packet transmission flow --- bumble/device.py | 8 ++++---- bumble/hci.py | 25 +++++++++++++++---------- bumble/host.py | 45 ++++++++++++++++++++++++--------------------- bumble/l2cap.py | 36 +++++++++++++++++++++--------------- 4 files changed, 64 insertions(+), 50 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index f27a780..46ce012 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -652,7 +652,7 @@ class Connection(CompositeEventEmitter): def is_incomplete(self) -> bool: return self.handle is None - def send_l2cap_pdu(self, cid, pdu): + def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None: self.device.send_l2cap_pdu(self.handle, cid, pdu) def create_l2cap_connector(self, psm): @@ -1096,7 +1096,7 @@ class Device(CompositeEventEmitter): return self._host @host.setter - def host(self, host): + def host(self, host: Host) -> None: # Unsubscribe from events from the current host if self._host: for event_name in device_host_event_handlers: @@ -1183,7 +1183,7 @@ class Device(CompositeEventEmitter): connection, psm, max_credits, mtu, mps ) - def send_l2cap_pdu(self, connection_handle, cid, pdu): + def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: self.host.send_l2cap_pdu(connection_handle, cid, pdu) async def send_command(self, command, check_result=False): @@ -3167,7 +3167,7 @@ class Device(CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_l2cap_pdu(self, connection, cid, pdu): + def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes): self.l2cap_channel_manager.on_pdu(connection, cid, pdu) def __str__(self): diff --git a/bumble/hci.py b/bumble/hci.py index 0dbb127..6bf97d9 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -20,7 +20,7 @@ import struct import collections import logging import functools -from typing import Dict, Type, Union +from typing import Dict, Type, Union, Callable, Any, Optional from .colors import color from .core import ( @@ -1918,7 +1918,7 @@ class HCI_Packet: hci_packet_type: int @staticmethod - def from_bytes(packet): + def from_bytes(packet: bytes) -> HCI_Packet: packet_type = packet[0] if packet_type == HCI_COMMAND_PACKET: @@ -1992,7 +1992,7 @@ class HCI_Command(HCI_Packet): return inner @staticmethod - def from_bytes(packet): + def from_bytes(packet: bytes) -> HCI_Command: op_code, length = struct.unpack_from(' HCI_Event: event_code = packet[1] length = packet[2] parameters = packet[3:] if len(parameters) != length: raise ValueError('invalid packet length') + cls: Type[HCI_Event | HCI_LE_Meta_Event] | None if event_code == HCI_LE_META_EVENT: # We do this dispatch here and not in the subclass in order to avoid call # loops @@ -4373,7 +4374,7 @@ class HCI_Event(HCI_Packet): return HCI_Event(event_code, parameters) # Invoke the factory to create a new instance - return cls.from_parameters(parameters) + return cls.from_parameters(parameters) # type: ignore @classmethod def from_parameters(cls, parameters): @@ -5086,6 +5087,7 @@ class HCI_Command_Complete_Event(HCI_Event): ''' return_parameters = b'' + command_opcode: int def map_return_parameters(self, return_parameters): '''Map simple 'status' return parameters to their named constant form''' @@ -5605,7 +5607,7 @@ class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event): # ----------------------------------------------------------------------------- -class HCI_AclDataPacket: +class HCI_AclDataPacket(HCI_Packet): ''' See Bluetooth spec @ 5.4.2 HCI ACL Data Packets ''' @@ -5613,7 +5615,7 @@ class HCI_AclDataPacket: hci_packet_type = HCI_ACL_DATA_PACKET @staticmethod - def from_bytes(packet): + def from_bytes(packet: bytes) -> HCI_AclDataPacket: # Read the header h, data_total_length = struct.unpack_from(' None: self.callback = callback self.current_data = None self.l2cap_pdu_length = 0 - def feed_packet(self, packet): + def feed_packet(self, packet: HCI_AclDataPacket) -> None: if packet.pb_flag in ( HCI_ACL_PB_FIRST_NON_FLUSHABLE, HCI_ACL_PB_FIRST_FLUSHABLE, @@ -5674,6 +5678,7 @@ class HCI_AclDataPacketAssembler: return self.current_data += packet.data + assert self.current_data is not None if len(self.current_data) == self.l2cap_pdu_length + 4: # The packet is complete, invoke the callback logger.debug(f'<<< ACL PDU: {self.current_data.hex()}') diff --git a/bumble/host.py b/bumble/host.py index fc4082a..288b1b6 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -15,6 +15,7 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import collections import logging @@ -30,8 +31,8 @@ from bumble import drivers from .hci import ( Address, HCI_ACL_DATA_PACKET, - HCI_COMMAND_COMPLETE_EVENT, HCI_COMMAND_PACKET, + HCI_COMMAND_COMPLETE_EVENT, HCI_EVENT_PACKET, HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND, @@ -45,8 +46,11 @@ from .hci import ( HCI_VERSION_BLUETOOTH_CORE_4_0, HCI_AclDataPacket, HCI_AclDataPacketAssembler, + HCI_Command, + HCI_Command_Complete_Event, HCI_Constant, HCI_Error, + HCI_Event, HCI_LE_Long_Term_Key_Request_Negative_Reply_Command, HCI_LE_Long_Term_Key_Request_Reply_Command, HCI_LE_Read_Buffer_Size_Command, @@ -95,17 +99,17 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 # ----------------------------------------------------------------------------- class Connection: - def __init__(self, host, handle, peer_address, transport): + def __init__(self, host: Host, handle: int, peer_address: Address, transport: int): self.host = host self.handle = handle self.peer_address = peer_address self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.transport = transport - def on_hci_acl_data_packet(self, packet): + def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: self.assembler.feed_packet(packet) - def on_acl_pdu(self, pdu): + def on_acl_pdu(self, pdu: bytes) -> None: l2cap_pdu = L2CAP_PDU.from_bytes(pdu) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) @@ -307,7 +311,7 @@ class Host(AbortableEventEmitter): def set_packet_sink(self, sink): self.hci_sink = sink - def send_hci_packet(self, packet): + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) @@ -356,13 +360,13 @@ class Host(AbortableEventEmitter): self.pending_response = None # Use this method to send a command from a task - def send_command_sync(self, command): - async def send_command(command): + def send_command_sync(self, command: HCI_Command) -> None: + async def send_command(command: HCI_Command) -> None: await self.send_command(command) asyncio.create_task(send_command(command)) - def send_l2cap_pdu(self, connection_handle, cid, pdu): + def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: l2cap_pdu = bytes(L2CAP_PDU(cid, pdu)) # Send the data to the controller via ACL packets @@ -387,7 +391,7 @@ class Host(AbortableEventEmitter): offset += data_total_length bytes_remaining -= data_total_length - def queue_acl_packet(self, acl_packet): + def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None: self.acl_packet_queue.appendleft(acl_packet) self.check_acl_packet_queue() @@ -397,7 +401,7 @@ class Host(AbortableEventEmitter): f'{len(self.acl_packet_queue)} in queue' ) - def check_acl_packet_queue(self): + def check_acl_packet_queue(self) -> None: # Send all we can (TODO: support different LE/Classic limits) while ( len(self.acl_packet_queue) > 0 @@ -443,11 +447,10 @@ class Host(AbortableEventEmitter): ] # Packet Sink protocol (packets coming from the controller via HCI) - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: hci_packet = HCI_Packet.from_bytes(packet) if self.ready or ( - hci_packet.hci_packet_type == HCI_EVENT_PACKET - and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT + isinstance(hci_packet, HCI_Command_Complete_Event) and hci_packet.command_opcode == HCI_RESET_COMMAND ): self.on_hci_packet(hci_packet) @@ -461,36 +464,36 @@ class Host(AbortableEventEmitter): self.emit('flush') - def on_hci_packet(self, packet): + def on_hci_packet(self, packet: HCI_Packet) -> None: logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) # If the packet is a command, invoke the handler for this packet - if packet.hci_packet_type == HCI_COMMAND_PACKET: + if isinstance(packet, HCI_Command): self.on_hci_command_packet(packet) - elif packet.hci_packet_type == HCI_EVENT_PACKET: + elif isinstance(packet, HCI_Event): self.on_hci_event_packet(packet) - elif packet.hci_packet_type == HCI_ACL_DATA_PACKET: + elif isinstance(packet, HCI_AclDataPacket): self.on_hci_acl_data_packet(packet) else: logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') - def on_hci_command_packet(self, command): + def on_hci_command_packet(self, command: HCI_Command) -> None: logger.warning(f'!!! unexpected command packet: {command}') - def on_hci_event_packet(self, event): + def on_hci_event_packet(self, event: HCI_Event) -> None: handler_name = f'on_{event.name.lower()}' handler = getattr(self, handler_name, self.on_hci_event) handler(event) - def on_hci_acl_data_packet(self, packet): + def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: # Look for the connection to which this data belongs if connection := self.connections.get(packet.connection_handle): connection.on_hci_acl_data_packet(packet) - def on_l2cap_pdu(self, connection, cid, pdu): + def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: self.emit('l2cap_pdu', connection.handle, cid, pdu) def on_command_processed(self, event): diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 4464afc..b83432a 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -33,6 +33,7 @@ from typing import ( Union, Deque, Iterable, + SupportsBytes, TYPE_CHECKING, ) @@ -47,6 +48,7 @@ from .hci import ( if TYPE_CHECKING: from bumble.device import Connection + from bumble.host import Host # ----------------------------------------------------------------------------- # Logging @@ -728,7 +730,7 @@ class Channel(EventEmitter): def __init__( self, - manager: 'ChannelManager', + manager: ChannelManager, connection: Connection, signaling_cid: int, psm: int, @@ -755,13 +757,13 @@ class Channel(EventEmitter): ) self.state = new_state - def send_pdu(self, pdu) -> None: + def send_pdu(self, pdu: SupportsBytes | bytes) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) - def send_control_frame(self, frame) -> None: + def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: self.manager.send_control_frame(self.connection, self.signaling_cid, frame) - async def send_request(self, request) -> bytes: + async def send_request(self, request: SupportsBytes) -> bytes: # Check that there isn't already a request pending if self.response: raise InvalidStateError('request already pending') @@ -772,7 +774,7 @@ class Channel(EventEmitter): self.send_pdu(request) return await self.response - def on_pdu(self, pdu) -> None: + def on_pdu(self, pdu: bytes) -> None: if self.response: self.response.set_result(pdu) self.response = None @@ -1041,7 +1043,7 @@ class LeConnectionOrientedChannel(EventEmitter): def __init__( self, - manager: 'ChannelManager', + manager: ChannelManager, connection: Connection, le_psm: int, source_cid: int, @@ -1096,10 +1098,10 @@ class LeConnectionOrientedChannel(EventEmitter): elif new_state == self.DISCONNECTED: self.emit('close') - def send_pdu(self, pdu) -> None: + def send_pdu(self, pdu: SupportsBytes | bytes) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) - def send_control_frame(self, frame) -> None: + def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) async def connect(self) -> LeConnectionOrientedChannel: @@ -1154,7 +1156,7 @@ class LeConnectionOrientedChannel(EventEmitter): if self.state == self.CONNECTED: self.change_state(self.DISCONNECTED) - def on_pdu(self, pdu) -> None: + def on_pdu(self, pdu: bytes) -> None: if self.sink is None: logger.warning('received pdu without a sink') return @@ -1384,6 +1386,7 @@ class ChannelManager: ] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] + _host: Optional[Host] def __init__( self, @@ -1407,11 +1410,12 @@ class ChannelManager: self.connectionless_mtu = connectionless_mtu @property - def host(self): + def host(self) -> Host: + assert self._host return self._host @host.setter - def host(self, host): + def host(self, host: Host) -> None: if self._host is not None: self._host.remove_listener('disconnection', self.on_disconnection) self._host = host @@ -1565,7 +1569,7 @@ class ChannelManager: if connection_handle in self.identifiers: del self.identifiers[connection_handle] - def send_pdu(self, connection, cid: int, pdu) -> None: + def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None: pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) logger.debug( f'{color(">>> Sending L2CAP PDU", "blue")} ' @@ -1574,7 +1578,7 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) - def on_pdu(self, connection: Connection, cid: int, pdu) -> None: + def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): # Parse the L2CAP payload into a Control Frame object control_frame = L2CAP_Control_Frame.from_bytes(pdu) @@ -1596,7 +1600,7 @@ class ChannelManager: channel.on_pdu(pdu) def send_control_frame( - self, connection: Connection, cid: int, control_frame + self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame ) -> None: logger.debug( f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} ' @@ -1605,7 +1609,9 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) - def on_control_frame(self, connection: Connection, cid: int, control_frame) -> None: + def on_control_frame( + self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame + ) -> None: logger.debug( f'{color("<<< Received L2CAP Signaling Control Frame", "green")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) '