Typing packet transmission flow

This commit is contained in:
Josh Wu
2023-08-30 00:48:01 +08:00
parent 7485801222
commit 249a205d8e
4 changed files with 64 additions and 50 deletions

View File

@@ -652,7 +652,7 @@ class Connection(CompositeEventEmitter):
def is_incomplete(self) -> bool: def is_incomplete(self) -> bool:
return self.handle is None return self.handle is None
def send_l2cap_pdu(self, cid, pdu): def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(self.handle, cid, pdu) self.device.send_l2cap_pdu(self.handle, cid, pdu)
def create_l2cap_connector(self, psm): def create_l2cap_connector(self, psm):
@@ -1096,7 +1096,7 @@ class Device(CompositeEventEmitter):
return self._host return self._host
@host.setter @host.setter
def host(self, host): def host(self, host: Host) -> None:
# Unsubscribe from events from the current host # Unsubscribe from events from the current host
if self._host: if self._host:
for event_name in device_host_event_handlers: for event_name in device_host_event_handlers:
@@ -1183,7 +1183,7 @@ class Device(CompositeEventEmitter):
connection, psm, max_credits, mtu, mps connection, psm, max_credits, mtu, mps
) )
def send_l2cap_pdu(self, connection_handle, cid, pdu): def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu) self.host.send_l2cap_pdu(connection_handle, cid, pdu)
async def send_command(self, command, check_result=False): async def send_command(self, command, check_result=False):
@@ -3167,7 +3167,7 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
@with_connection_from_handle @with_connection_from_handle
def on_l2cap_pdu(self, connection, cid, pdu): def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes):
self.l2cap_channel_manager.on_pdu(connection, cid, pdu) self.l2cap_channel_manager.on_pdu(connection, cid, pdu)
def __str__(self): def __str__(self):

View File

@@ -20,7 +20,7 @@ import struct
import collections import collections
import logging import logging
import functools import functools
from typing import Dict, Type, Union from typing import Dict, Type, Union, Callable, Any, Optional
from .colors import color from .colors import color
from .core import ( from .core import (
@@ -1918,7 +1918,7 @@ class HCI_Packet:
hci_packet_type: int hci_packet_type: int
@staticmethod @staticmethod
def from_bytes(packet): def from_bytes(packet: bytes) -> HCI_Packet:
packet_type = packet[0] packet_type = packet[0]
if packet_type == HCI_COMMAND_PACKET: if packet_type == HCI_COMMAND_PACKET:
@@ -1992,7 +1992,7 @@ class HCI_Command(HCI_Packet):
return inner return inner
@staticmethod @staticmethod
def from_bytes(packet): def from_bytes(packet: bytes) -> HCI_Command:
op_code, length = struct.unpack_from('<HB', packet, 1) op_code, length = struct.unpack_from('<HB', packet, 1)
parameters = packet[4:] parameters = packet[4:]
if len(parameters) != length: if len(parameters) != length:
@@ -2011,7 +2011,7 @@ class HCI_Command(HCI_Packet):
HCI_Object.init_from_bytes(self, parameters, 0, fields) HCI_Object.init_from_bytes(self, parameters, 0, fields)
return self return self
return cls.from_parameters(parameters) return cls.from_parameters(parameters) # type: ignore
@staticmethod @staticmethod
def command_name(op_code): def command_name(op_code):
@@ -4350,13 +4350,14 @@ class HCI_Event(HCI_Packet):
return event_class return event_class
@staticmethod @staticmethod
def from_bytes(packet): def from_bytes(packet: bytes) -> HCI_Event:
event_code = packet[1] event_code = packet[1]
length = packet[2] length = packet[2]
parameters = packet[3:] parameters = packet[3:]
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise ValueError('invalid packet length')
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
if event_code == HCI_LE_META_EVENT: if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call # We do this dispatch here and not in the subclass in order to avoid call
# loops # loops
@@ -4373,7 +4374,7 @@ class HCI_Event(HCI_Packet):
return HCI_Event(event_code, parameters) return HCI_Event(event_code, parameters)
# Invoke the factory to create a new instance # Invoke the factory to create a new instance
return cls.from_parameters(parameters) return cls.from_parameters(parameters) # type: ignore
@classmethod @classmethod
def from_parameters(cls, parameters): def from_parameters(cls, parameters):
@@ -5086,6 +5087,7 @@ class HCI_Command_Complete_Event(HCI_Event):
''' '''
return_parameters = b'' return_parameters = b''
command_opcode: int
def map_return_parameters(self, return_parameters): def map_return_parameters(self, return_parameters):
'''Map simple 'status' return parameters to their named constant form''' '''Map simple 'status' return parameters to their named constant form'''
@@ -5605,7 +5607,7 @@ class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HCI_AclDataPacket: class HCI_AclDataPacket(HCI_Packet):
''' '''
See Bluetooth spec @ 5.4.2 HCI ACL Data Packets See Bluetooth spec @ 5.4.2 HCI ACL Data Packets
''' '''
@@ -5613,7 +5615,7 @@ class HCI_AclDataPacket:
hci_packet_type = HCI_ACL_DATA_PACKET hci_packet_type = HCI_ACL_DATA_PACKET
@staticmethod @staticmethod
def from_bytes(packet): def from_bytes(packet: bytes) -> HCI_AclDataPacket:
# Read the header # Read the header
h, data_total_length = struct.unpack_from('<HH', packet, 1) h, data_total_length = struct.unpack_from('<HH', packet, 1)
connection_handle = h & 0xFFF connection_handle = h & 0xFFF
@@ -5655,12 +5657,14 @@ class HCI_AclDataPacket:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HCI_AclDataPacketAssembler: class HCI_AclDataPacketAssembler:
def __init__(self, callback): current_data: Optional[bytes]
def __init__(self, callback: Callable[[bytes], Any]) -> None:
self.callback = callback self.callback = callback
self.current_data = None self.current_data = None
self.l2cap_pdu_length = 0 self.l2cap_pdu_length = 0
def feed_packet(self, packet): def feed_packet(self, packet: HCI_AclDataPacket) -> None:
if packet.pb_flag in ( if packet.pb_flag in (
HCI_ACL_PB_FIRST_NON_FLUSHABLE, HCI_ACL_PB_FIRST_NON_FLUSHABLE,
HCI_ACL_PB_FIRST_FLUSHABLE, HCI_ACL_PB_FIRST_FLUSHABLE,
@@ -5674,6 +5678,7 @@ class HCI_AclDataPacketAssembler:
return return
self.current_data += packet.data self.current_data += packet.data
assert self.current_data is not None
if len(self.current_data) == self.l2cap_pdu_length + 4: if len(self.current_data) == self.l2cap_pdu_length + 4:
# The packet is complete, invoke the callback # The packet is complete, invoke the callback
logger.debug(f'<<< ACL PDU: {self.current_data.hex()}') logger.debug(f'<<< ACL PDU: {self.current_data.hex()}')

View File

@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import collections import collections
import logging import logging
@@ -30,8 +31,8 @@ from bumble import drivers
from .hci import ( from .hci import (
Address, Address,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET, HCI_COMMAND_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND, HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
@@ -45,8 +46,11 @@ from .hci import (
HCI_VERSION_BLUETOOTH_CORE_4_0, HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket, HCI_AclDataPacket,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Constant, HCI_Constant,
HCI_Error, HCI_Error,
HCI_Event,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command, HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command, HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command, HCI_LE_Read_Buffer_Size_Command,
@@ -95,17 +99,17 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection: class Connection:
def __init__(self, host, handle, peer_address, transport): def __init__(self, host: Host, handle: int, peer_address: Address, transport: int):
self.host = host self.host = host
self.handle = handle self.handle = handle
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
def on_hci_acl_data_packet(self, packet): def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet) self.assembler.feed_packet(packet)
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
@@ -307,7 +311,7 @@ class Host(AbortableEventEmitter):
def set_packet_sink(self, sink): def set_packet_sink(self, sink):
self.hci_sink = sink self.hci_sink = sink
def send_hci_packet(self, packet): def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper: if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
@@ -356,13 +360,13 @@ class Host(AbortableEventEmitter):
self.pending_response = None self.pending_response = None
# Use this method to send a command from a task # Use this method to send a command from a task
def send_command_sync(self, command): def send_command_sync(self, command: HCI_Command) -> None:
async def send_command(command): async def send_command(command: HCI_Command) -> None:
await self.send_command(command) await self.send_command(command)
asyncio.create_task(send_command(command)) asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle, cid, pdu): def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu)) l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets # Send the data to the controller via ACL packets
@@ -387,7 +391,7 @@ class Host(AbortableEventEmitter):
offset += data_total_length offset += data_total_length
bytes_remaining -= data_total_length bytes_remaining -= data_total_length
def queue_acl_packet(self, acl_packet): def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet) self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue() self.check_acl_packet_queue()
@@ -397,7 +401,7 @@ class Host(AbortableEventEmitter):
f'{len(self.acl_packet_queue)} in queue' f'{len(self.acl_packet_queue)} in queue'
) )
def check_acl_packet_queue(self): def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits) # Send all we can (TODO: support different LE/Classic limits)
while ( while (
len(self.acl_packet_queue) > 0 len(self.acl_packet_queue) > 0
@@ -443,11 +447,10 @@ class Host(AbortableEventEmitter):
] ]
# Packet Sink protocol (packets coming from the controller via HCI) # Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
hci_packet = HCI_Packet.from_bytes(packet) hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or ( if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET isinstance(hci_packet, HCI_Command_Complete_Event)
and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
and hci_packet.command_opcode == HCI_RESET_COMMAND and hci_packet.command_opcode == HCI_RESET_COMMAND
): ):
self.on_hci_packet(hci_packet) self.on_hci_packet(hci_packet)
@@ -461,36 +464,36 @@ class Host(AbortableEventEmitter):
self.emit('flush') self.emit('flush')
def on_hci_packet(self, packet): def on_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper: if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet # If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET: if isinstance(packet, HCI_Command):
self.on_hci_command_packet(packet) self.on_hci_command_packet(packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET: elif isinstance(packet, HCI_Event):
self.on_hci_event_packet(packet) self.on_hci_event_packet(packet)
elif packet.hci_packet_type == HCI_ACL_DATA_PACKET: elif isinstance(packet, HCI_AclDataPacket):
self.on_hci_acl_data_packet(packet) self.on_hci_acl_data_packet(packet)
else: else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command): def on_hci_command_packet(self, command: HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}') logger.warning(f'!!! unexpected command packet: {command}')
def on_hci_event_packet(self, event): def on_hci_event_packet(self, event: HCI_Event) -> None:
handler_name = f'on_{event.name.lower()}' handler_name = f'on_{event.name.lower()}'
handler = getattr(self, handler_name, self.on_hci_event) handler = getattr(self, handler_name, self.on_hci_event)
handler(event) handler(event)
def on_hci_acl_data_packet(self, packet): def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
# Look for the connection to which this data belongs # Look for the connection to which this data belongs
if connection := self.connections.get(packet.connection_handle): if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet) connection.on_hci_acl_data_packet(packet)
def on_l2cap_pdu(self, connection, cid, pdu): def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
self.emit('l2cap_pdu', connection.handle, cid, pdu) self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(self, event): def on_command_processed(self, event):

View File

@@ -33,6 +33,7 @@ from typing import (
Union, Union,
Deque, Deque,
Iterable, Iterable,
SupportsBytes,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -47,6 +48,7 @@ from .hci import (
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Connection from bumble.device import Connection
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -728,7 +730,7 @@ class Channel(EventEmitter):
def __init__( def __init__(
self, self,
manager: 'ChannelManager', manager: ChannelManager,
connection: Connection, connection: Connection,
signaling_cid: int, signaling_cid: int,
psm: int, psm: int,
@@ -755,13 +757,13 @@ class Channel(EventEmitter):
) )
self.state = new_state self.state = new_state
def send_pdu(self, pdu) -> None: def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame) self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request) -> bytes: async def send_request(self, request: SupportsBytes) -> bytes:
# Check that there isn't already a request pending # Check that there isn't already a request pending
if self.response: if self.response:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
@@ -772,7 +774,7 @@ class Channel(EventEmitter):
self.send_pdu(request) self.send_pdu(request)
return await self.response return await self.response
def on_pdu(self, pdu) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.response: if self.response:
self.response.set_result(pdu) self.response.set_result(pdu)
self.response = None self.response = None
@@ -1041,7 +1043,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __init__( def __init__(
self, self,
manager: 'ChannelManager', manager: ChannelManager,
connection: Connection, connection: Connection,
le_psm: int, le_psm: int,
source_cid: int, source_cid: int,
@@ -1096,10 +1098,10 @@ class LeConnectionOrientedChannel(EventEmitter):
elif new_state == self.DISCONNECTED: elif new_state == self.DISCONNECTED:
self.emit('close') self.emit('close')
def send_pdu(self, pdu) -> None: def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self) -> LeConnectionOrientedChannel: async def connect(self) -> LeConnectionOrientedChannel:
@@ -1154,7 +1156,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if self.state == self.CONNECTED: if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED) self.change_state(self.DISCONNECTED)
def on_pdu(self, pdu) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.sink is None: if self.sink is None:
logger.warning('received pdu without a sink') logger.warning('received pdu without a sink')
return return
@@ -1384,6 +1386,7 @@ class ChannelManager:
] ]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
def __init__( def __init__(
self, self,
@@ -1407,11 +1410,12 @@ class ChannelManager:
self.connectionless_mtu = connectionless_mtu self.connectionless_mtu = connectionless_mtu
@property @property
def host(self): def host(self) -> Host:
assert self._host
return self._host return self._host
@host.setter @host.setter
def host(self, host): def host(self, host: Host) -> None:
if self._host is not None: if self._host is not None:
self._host.remove_listener('disconnection', self.on_disconnection) self._host.remove_listener('disconnection', self.on_disconnection)
self._host = host self._host = host
@@ -1565,7 +1569,7 @@ class ChannelManager:
if connection_handle in self.identifiers: if connection_handle in self.identifiers:
del self.identifiers[connection_handle] del self.identifiers[connection_handle]
def send_pdu(self, connection, cid: int, pdu) -> None: def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} ' f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1574,7 +1578,7 @@ class ChannelManager:
) )
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection: Connection, cid: int, pdu) -> None: def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object # Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu) control_frame = L2CAP_Control_Frame.from_bytes(pdu)
@@ -1596,7 +1600,7 @@ class ChannelManager:
channel.on_pdu(pdu) channel.on_pdu(pdu)
def send_control_frame( def send_control_frame(
self, connection: Connection, cid: int, control_frame self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
) -> None: ) -> None:
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} ' f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
@@ -1605,7 +1609,9 @@ class ChannelManager:
) )
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
def on_control_frame(self, connection: Connection, cid: int, control_frame) -> None: def on_control_frame(
self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
) -> None:
logger.debug( logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} ' f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) ' f'on connection [0x{connection.handle:04X}] (CID={cid}) '