Typing helper

This commit is contained in:
Josh Wu
2023-11-27 16:10:01 +08:00
parent f0e5cdee1a
commit f3cd8f8ed0
3 changed files with 79 additions and 42 deletions

View File

@@ -5296,6 +5296,10 @@ class HCI_Disconnection_Complete_Event(HCI_Event):
See Bluetooth spec @ 7.7.5 Disconnection Complete Event See Bluetooth spec @ 7.7.5 Disconnection Complete Event
''' '''
status: int
connection_handle: int
reason: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)]) @HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)])

View File

@@ -15,30 +15,39 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import cast, Any
import logging import logging
from .colors import color from bumble import avdtp
from .att import ATT_CID, ATT_PDU from bumble.colors import color
from .smp import SMP_CID, SMP_Command from bumble.att import ATT_CID, ATT_PDU
from .core import name_or_number from bumble.smp import SMP_CID, SMP_Command
from .l2cap import ( from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST, L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE, L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response, L2CAP_Connection_Response,
) )
from .hci import ( from bumble.hci import (
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
) )
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM from bumble.sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -50,23 +59,25 @@ logger = logging.getLogger(__name__)
PSM_NAMES = { PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM', RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP', SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP' avdtp.AVDTP_PSM: 'AVDTP',
# TODO: add more PSM values
} }
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
def __init__(self, analyzer): psms: MutableMapping[int, int]
peer: PacketTracer.AclStream
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks # pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
if l2cap_pdu.cid == ATT_CID: if l2cap_pdu.cid == ATT_CID:
@@ -81,26 +92,30 @@ class PacketTracer:
# Check if this signals a new channel # Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if ( if (
control_frame.result connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
): ):
if self.peer: if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid): if psm := self.peer.psms.get(
connection_response.source_cid
):
# Found a pending connection # Found a pending connection
self.psms[control_frame.destination_cid] = psm self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for # For AVDTP connections, create a packet assembler for
# each direction # each direction
if psm == AVDTP_PSM: if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[ self.avdtp_assemblers[
control_frame.source_cid connection_response.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message) ] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[ self.peer.avdtp_assemblers[
control_frame.destination_cid connection_response.destination_cid
] = AVDTP_MessageAssembler( ] = avdtp.MessageAssembler(
self.peer.on_avdtp_message self.peer.on_avdtp_message
) )
@@ -113,7 +128,7 @@ class PacketTracer:
elif psm == RFCOMM_PSM: elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM: elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit( self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}' f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
@@ -130,22 +145,26 @@ class PacketTracer:
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message): def on_avdtp_message(
self, transaction_label: int, message: avdtp.Message
) -> None:
self.analyzer.emit( self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}' f'{color("AVDTP", "green")} [{transaction_label}] {message}'
) )
def feed_packet(self, packet): def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
class Analyzer: class Analyzer:
def __init__(self, label, emit_message): acl_streams: MutableMapping[int, PacketTracer.AclStream]
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle): def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info( logger.info(
f'[{self.label}] +++ Creating ACL stream for connection ' f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -160,7 +179,7 @@ class PacketTracer:
return stream return stream
def end_acl_stream(self, connection_handle): def end_acl_stream(self, connection_handle: int) -> None:
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info( logger.info(
f'[{self.label}] --- Removing ACL stream for connection ' f'[{self.label}] --- Removing ACL stream for connection '
@@ -171,23 +190,29 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle) self.peer.end_acl_stream(connection_handle)
def on_packet(self, packet): def on_packet(self, packet: HCI_Packet) -> None:
self.emit(packet) self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET: if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the # Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle # first ACL packet for that connection handle
if (stream := self.acl_streams.get(packet.connection_handle)) is None: if (
stream = self.start_acl_stream(packet.connection_handle) stream := self.acl_streams.get(acl_packet.connection_handle)
stream.feed_packet(packet) ) is None:
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET: elif packet.hci_packet_type == HCI_EVENT_PACKET:
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT: event_packet = cast(HCI_Event, packet)
self.end_acl_stream(packet.connection_handle) if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message): def emit(self, message: Any) -> None:
self.emit_message(f'[{self.label}] {message}') self.emit_message(f'[{self.label}] {message}')
def trace(self, packet, direction=0): def trace(self, packet: HCI_Packet, direction: int = 0) -> None:
if direction == 0: if direction == 0:
self.host_to_controller_analyzer.on_packet(packet) self.host_to_controller_analyzer.on_packet(packet)
else: else:
@@ -195,10 +220,10 @@ class PacketTracer:
def __init__( def __init__(
self, self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'), host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'), controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info, emit_message: Callable[..., None] = logger.info,
): ) -> None:
self.host_to_controller_analyzer = PacketTracer.Analyzer( self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message host_to_controller_label, emit_message
) )

View File

@@ -391,6 +391,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
''' '''
psm: int
source_cid: int
@staticmethod @staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2 psm_length = 2
@@ -432,6 +435,11 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
''' '''
source_cid: int
destination_cid: int
status: int
result: int
CONNECTION_SUCCESSFUL = 0x0000 CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001 CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002 CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002