forked from auracaster/bumble_mirror
Typing packet transmission flow
This commit is contained in:
@@ -652,7 +652,7 @@ class Connection(CompositeEventEmitter):
|
||||
def is_incomplete(self) -> bool:
|
||||
return self.handle is None
|
||||
|
||||
def send_l2cap_pdu(self, cid, pdu):
|
||||
def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(self.handle, cid, pdu)
|
||||
|
||||
def create_l2cap_connector(self, psm):
|
||||
@@ -1096,7 +1096,7 @@ class Device(CompositeEventEmitter):
|
||||
return self._host
|
||||
|
||||
@host.setter
|
||||
def host(self, host):
|
||||
def host(self, host: Host) -> None:
|
||||
# Unsubscribe from events from the current host
|
||||
if self._host:
|
||||
for event_name in device_host_event_handlers:
|
||||
@@ -1183,7 +1183,7 @@ class Device(CompositeEventEmitter):
|
||||
connection, psm, max_credits, mtu, mps
|
||||
)
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle, cid, pdu):
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||
|
||||
async def send_command(self, command, check_result=False):
|
||||
@@ -3167,7 +3167,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_l2cap_pdu(self, connection, cid, pdu):
|
||||
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes):
|
||||
self.l2cap_channel_manager.on_pdu(connection, cid, pdu)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -20,7 +20,7 @@ import struct
|
||||
import collections
|
||||
import logging
|
||||
import functools
|
||||
from typing import Dict, Type, Union
|
||||
from typing import Dict, Type, Union, Callable, Any, Optional
|
||||
|
||||
from .colors import color
|
||||
from .core import (
|
||||
@@ -1918,7 +1918,7 @@ class HCI_Packet:
|
||||
hci_packet_type: int
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet):
|
||||
def from_bytes(packet: bytes) -> HCI_Packet:
|
||||
packet_type = packet[0]
|
||||
|
||||
if packet_type == HCI_COMMAND_PACKET:
|
||||
@@ -1992,7 +1992,7 @@ class HCI_Command(HCI_Packet):
|
||||
return inner
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet):
|
||||
def from_bytes(packet: bytes) -> HCI_Command:
|
||||
op_code, length = struct.unpack_from('<HB', packet, 1)
|
||||
parameters = packet[4:]
|
||||
if len(parameters) != length:
|
||||
@@ -2011,7 +2011,7 @@ class HCI_Command(HCI_Packet):
|
||||
HCI_Object.init_from_bytes(self, parameters, 0, fields)
|
||||
return self
|
||||
|
||||
return cls.from_parameters(parameters)
|
||||
return cls.from_parameters(parameters) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def command_name(op_code):
|
||||
@@ -4350,13 +4350,14 @@ class HCI_Event(HCI_Packet):
|
||||
return event_class
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet):
|
||||
def from_bytes(packet: bytes) -> HCI_Event:
|
||||
event_code = packet[1]
|
||||
length = packet[2]
|
||||
parameters = packet[3:]
|
||||
if len(parameters) != length:
|
||||
raise ValueError('invalid packet length')
|
||||
|
||||
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
|
||||
if event_code == HCI_LE_META_EVENT:
|
||||
# We do this dispatch here and not in the subclass in order to avoid call
|
||||
# loops
|
||||
@@ -4373,7 +4374,7 @@ class HCI_Event(HCI_Packet):
|
||||
return HCI_Event(event_code, parameters)
|
||||
|
||||
# Invoke the factory to create a new instance
|
||||
return cls.from_parameters(parameters)
|
||||
return cls.from_parameters(parameters) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def from_parameters(cls, parameters):
|
||||
@@ -5086,6 +5087,7 @@ class HCI_Command_Complete_Event(HCI_Event):
|
||||
'''
|
||||
|
||||
return_parameters = b''
|
||||
command_opcode: int
|
||||
|
||||
def map_return_parameters(self, return_parameters):
|
||||
'''Map simple 'status' return parameters to their named constant form'''
|
||||
@@ -5605,7 +5607,7 @@ class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HCI_AclDataPacket:
|
||||
class HCI_AclDataPacket(HCI_Packet):
|
||||
'''
|
||||
See Bluetooth spec @ 5.4.2 HCI ACL Data Packets
|
||||
'''
|
||||
@@ -5613,7 +5615,7 @@ class HCI_AclDataPacket:
|
||||
hci_packet_type = HCI_ACL_DATA_PACKET
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet):
|
||||
def from_bytes(packet: bytes) -> HCI_AclDataPacket:
|
||||
# Read the header
|
||||
h, data_total_length = struct.unpack_from('<HH', packet, 1)
|
||||
connection_handle = h & 0xFFF
|
||||
@@ -5655,12 +5657,14 @@ class HCI_AclDataPacket:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HCI_AclDataPacketAssembler:
|
||||
def __init__(self, callback):
|
||||
current_data: Optional[bytes]
|
||||
|
||||
def __init__(self, callback: Callable[[bytes], Any]) -> None:
|
||||
self.callback = callback
|
||||
self.current_data = None
|
||||
self.l2cap_pdu_length = 0
|
||||
|
||||
def feed_packet(self, packet):
|
||||
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
if packet.pb_flag in (
|
||||
HCI_ACL_PB_FIRST_NON_FLUSHABLE,
|
||||
HCI_ACL_PB_FIRST_FLUSHABLE,
|
||||
@@ -5674,6 +5678,7 @@ class HCI_AclDataPacketAssembler:
|
||||
return
|
||||
self.current_data += packet.data
|
||||
|
||||
assert self.current_data is not None
|
||||
if len(self.current_data) == self.l2cap_pdu_length + 4:
|
||||
# The packet is complete, invoke the callback
|
||||
logger.debug(f'<<< ACL PDU: {self.current_data.hex()}')
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
@@ -30,8 +31,8 @@ from bumble import drivers
|
||||
from .hci import (
|
||||
Address,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_COMMAND_COMPLETE_EVENT,
|
||||
HCI_COMMAND_PACKET,
|
||||
HCI_COMMAND_COMPLETE_EVENT,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
|
||||
@@ -45,8 +46,11 @@ from .hci import (
|
||||
HCI_VERSION_BLUETOOTH_CORE_4_0,
|
||||
HCI_AclDataPacket,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Constant,
|
||||
HCI_Error,
|
||||
HCI_Event,
|
||||
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
|
||||
HCI_LE_Long_Term_Key_Request_Reply_Command,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
@@ -95,17 +99,17 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Connection:
|
||||
def __init__(self, host, handle, peer_address, transport):
|
||||
def __init__(self, host: Host, handle: int, peer_address: Address, transport: int):
|
||||
self.host = host
|
||||
self.handle = handle
|
||||
self.peer_address = peer_address
|
||||
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
|
||||
def on_hci_acl_data_packet(self, packet):
|
||||
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.assembler.feed_packet(packet)
|
||||
|
||||
def on_acl_pdu(self, pdu):
|
||||
def on_acl_pdu(self, pdu: bytes) -> None:
|
||||
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
|
||||
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
|
||||
|
||||
@@ -307,7 +311,7 @@ class Host(AbortableEventEmitter):
|
||||
def set_packet_sink(self, sink):
|
||||
self.hci_sink = sink
|
||||
|
||||
def send_hci_packet(self, packet):
|
||||
def send_hci_packet(self, packet: HCI_Packet) -> None:
|
||||
if self.snooper:
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
|
||||
@@ -356,13 +360,13 @@ class Host(AbortableEventEmitter):
|
||||
self.pending_response = None
|
||||
|
||||
# Use this method to send a command from a task
|
||||
def send_command_sync(self, command):
|
||||
async def send_command(command):
|
||||
def send_command_sync(self, command: HCI_Command) -> None:
|
||||
async def send_command(command: HCI_Command) -> None:
|
||||
await self.send_command(command)
|
||||
|
||||
asyncio.create_task(send_command(command))
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle, cid, pdu):
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
|
||||
|
||||
# Send the data to the controller via ACL packets
|
||||
@@ -387,7 +391,7 @@ class Host(AbortableEventEmitter):
|
||||
offset += data_total_length
|
||||
bytes_remaining -= data_total_length
|
||||
|
||||
def queue_acl_packet(self, acl_packet):
|
||||
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
|
||||
self.acl_packet_queue.appendleft(acl_packet)
|
||||
self.check_acl_packet_queue()
|
||||
|
||||
@@ -397,7 +401,7 @@ class Host(AbortableEventEmitter):
|
||||
f'{len(self.acl_packet_queue)} in queue'
|
||||
)
|
||||
|
||||
def check_acl_packet_queue(self):
|
||||
def check_acl_packet_queue(self) -> None:
|
||||
# Send all we can (TODO: support different LE/Classic limits)
|
||||
while (
|
||||
len(self.acl_packet_queue) > 0
|
||||
@@ -443,11 +447,10 @@ class Host(AbortableEventEmitter):
|
||||
]
|
||||
|
||||
# Packet Sink protocol (packets coming from the controller via HCI)
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
hci_packet = HCI_Packet.from_bytes(packet)
|
||||
if self.ready or (
|
||||
hci_packet.hci_packet_type == HCI_EVENT_PACKET
|
||||
and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
|
||||
isinstance(hci_packet, HCI_Command_Complete_Event)
|
||||
and hci_packet.command_opcode == HCI_RESET_COMMAND
|
||||
):
|
||||
self.on_hci_packet(hci_packet)
|
||||
@@ -461,36 +464,36 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
self.emit('flush')
|
||||
|
||||
def on_hci_packet(self, packet):
|
||||
def on_hci_packet(self, packet: HCI_Packet) -> None:
|
||||
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
|
||||
|
||||
if self.snooper:
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
|
||||
# If the packet is a command, invoke the handler for this packet
|
||||
if packet.hci_packet_type == HCI_COMMAND_PACKET:
|
||||
if isinstance(packet, HCI_Command):
|
||||
self.on_hci_command_packet(packet)
|
||||
elif packet.hci_packet_type == HCI_EVENT_PACKET:
|
||||
elif isinstance(packet, HCI_Event):
|
||||
self.on_hci_event_packet(packet)
|
||||
elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
|
||||
elif isinstance(packet, HCI_AclDataPacket):
|
||||
self.on_hci_acl_data_packet(packet)
|
||||
else:
|
||||
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
|
||||
|
||||
def on_hci_command_packet(self, command):
|
||||
def on_hci_command_packet(self, command: HCI_Command) -> None:
|
||||
logger.warning(f'!!! unexpected command packet: {command}')
|
||||
|
||||
def on_hci_event_packet(self, event):
|
||||
def on_hci_event_packet(self, event: HCI_Event) -> None:
|
||||
handler_name = f'on_{event.name.lower()}'
|
||||
handler = getattr(self, handler_name, self.on_hci_event)
|
||||
handler(event)
|
||||
|
||||
def on_hci_acl_data_packet(self, packet):
|
||||
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
# Look for the connection to which this data belongs
|
||||
if connection := self.connections.get(packet.connection_handle):
|
||||
connection.on_hci_acl_data_packet(packet)
|
||||
|
||||
def on_l2cap_pdu(self, connection, cid, pdu):
|
||||
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
self.emit('l2cap_pdu', connection.handle, cid, pdu)
|
||||
|
||||
def on_command_processed(self, event):
|
||||
|
||||
@@ -33,6 +33,7 @@ from typing import (
|
||||
Union,
|
||||
Deque,
|
||||
Iterable,
|
||||
SupportsBytes,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,7 @@ from .hci import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -728,7 +730,7 @@ class Channel(EventEmitter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: 'ChannelManager',
|
||||
manager: ChannelManager,
|
||||
connection: Connection,
|
||||
signaling_cid: int,
|
||||
psm: int,
|
||||
@@ -755,13 +757,13 @@ class Channel(EventEmitter):
|
||||
)
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu) -> None:
|
||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame) -> None:
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
|
||||
async def send_request(self, request) -> bytes:
|
||||
async def send_request(self, request: SupportsBytes) -> bytes:
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
@@ -772,7 +774,7 @@ class Channel(EventEmitter):
|
||||
self.send_pdu(request)
|
||||
return await self.response
|
||||
|
||||
def on_pdu(self, pdu) -> None:
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.response:
|
||||
self.response.set_result(pdu)
|
||||
self.response = None
|
||||
@@ -1041,7 +1043,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: 'ChannelManager',
|
||||
manager: ChannelManager,
|
||||
connection: Connection,
|
||||
le_psm: int,
|
||||
source_cid: int,
|
||||
@@ -1096,10 +1098,10 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
elif new_state == self.DISCONNECTED:
|
||||
self.emit('close')
|
||||
|
||||
def send_pdu(self, pdu) -> None:
|
||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame) -> None:
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
|
||||
|
||||
async def connect(self) -> LeConnectionOrientedChannel:
|
||||
@@ -1154,7 +1156,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
if self.state == self.CONNECTED:
|
||||
self.change_state(self.DISCONNECTED)
|
||||
|
||||
def on_pdu(self, pdu) -> None:
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.sink is None:
|
||||
logger.warning('received pdu without a sink')
|
||||
return
|
||||
@@ -1384,6 +1386,7 @@ class ChannelManager:
|
||||
]
|
||||
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
|
||||
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
|
||||
_host: Optional[Host]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1407,11 +1410,12 @@ class ChannelManager:
|
||||
self.connectionless_mtu = connectionless_mtu
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
def host(self) -> Host:
|
||||
assert self._host
|
||||
return self._host
|
||||
|
||||
@host.setter
|
||||
def host(self, host):
|
||||
def host(self, host: Host) -> None:
|
||||
if self._host is not None:
|
||||
self._host.remove_listener('disconnection', self.on_disconnection)
|
||||
self._host = host
|
||||
@@ -1565,7 +1569,7 @@ class ChannelManager:
|
||||
if connection_handle in self.identifiers:
|
||||
del self.identifiers[connection_handle]
|
||||
|
||||
def send_pdu(self, connection, cid: int, pdu) -> None:
|
||||
def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
|
||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||
@@ -1574,7 +1578,7 @@ class ChannelManager:
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
|
||||
|
||||
def on_pdu(self, connection: Connection, cid: int, pdu) -> None:
|
||||
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
|
||||
# Parse the L2CAP payload into a Control Frame object
|
||||
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
|
||||
@@ -1596,7 +1600,7 @@ class ChannelManager:
|
||||
channel.on_pdu(pdu)
|
||||
|
||||
def send_control_frame(
|
||||
self, connection: Connection, cid: int, control_frame
|
||||
self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
|
||||
@@ -1605,7 +1609,9 @@ class ChannelManager:
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
|
||||
|
||||
def on_control_frame(self, connection: Connection, cid: int, control_frame) -> None:
|
||||
def on_control_frame(
|
||||
self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
|
||||
Reference in New Issue
Block a user