diff --git a/bumble/device.py b/bumble/device.py index ff49576..557602c 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -2080,6 +2080,12 @@ class DeviceConfiguration: io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT gap_service_enabled: bool = True gatt_service_enabled: bool = True + enhanced_retransmission_supported: bool = False + l2cap_extended_features: Sequence[int] = ( + l2cap.L2CAP_Information_Request.ExtendedFeatures.FIXED_CHANNELS, + l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION, + l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE, + ) def __post_init__(self) -> None: self.gatt_services: list[dict[str, Any]] = [] @@ -2349,6 +2355,10 @@ class Device(utils.CompositeEventEmitter): ) -> None: super().__init__() + # Use the initial config or a default + config = config or DeviceConfiguration() + self.config = config + self._host = None self.powered_on = False self.auto_restart_inquiry = True @@ -2356,7 +2366,7 @@ class Device(utils.CompositeEventEmitter): self.gatt_server = gatt_server.Server(self) self.sdp_server = sdp.Server(self) self.l2cap_channel_manager = l2cap.ChannelManager( - [l2cap.L2CAP_Information_Request.EXTENDED_FEATURE_FIXED_CHANNELS] + config.l2cap_extended_features ) self.advertisement_accumulators = {} # Accumulators, by address self.periodic_advertising_syncs = [] @@ -2387,10 +2397,6 @@ class Device(utils.CompositeEventEmitter): # Own address type cache self.connect_own_address_type = None - # Use the initial config or a default - config = config or DeviceConfiguration() - self.config = config - self.name = config.name self.public_address = hci.Address.ANY self.random_address = config.address diff --git a/bumble/host.py b/bumble/host.py index bc8a792..b470eee 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -707,7 +707,7 @@ class Host(utils.EventEmitter): asyncio.create_task(send_command(command)) - def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: + def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None: if not (connection := self.connections.get(connection_handle)): logger.warning(f'connection 0x{connection_handle:04X} not found') return @@ -718,27 +718,24 @@ class Host(utils.EventEmitter): ) return - # Create a PDU - l2cap_pdu = bytes(L2CAP_PDU(cid, pdu)) - # Send the data to the controller via ACL packets - bytes_remaining = len(l2cap_pdu) - offset = 0 - pb_flag = 0 - while bytes_remaining: - data_total_length = min(bytes_remaining, packet_queue.max_packet_size) + max_packet_size = packet_queue.max_packet_size + for offset in range(0, len(sdu), max_packet_size): + pdu = sdu[offset : offset + max_packet_size] acl_packet = hci.HCI_AclDataPacket( connection_handle=connection_handle, - pb_flag=pb_flag, + pb_flag=1 if offset > 0 else 0, bc_flag=0, - data_total_length=data_total_length, - data=l2cap_pdu[offset : offset + data_total_length], + data_total_length=len(pdu), + data=pdu, + ) + logger.debug( + '>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu ) - logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}') packet_queue.enqueue(acl_packet, connection_handle) - pb_flag = 1 - offset += data_total_length - bytes_remaining -= data_total_length + + def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: + self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu))) def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None: if connection := self.connections.get(connection_handle): diff --git a/bumble/l2cap.py b/bumble/l2cap.py index f770866..72b39f7 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -23,18 +23,10 @@ import enum import logging import struct from collections import deque -from collections.abc import Sequence -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Iterable, - Optional, - SupportsBytes, - TypeVar, - Union, -) +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Optional, SupportsBytes, TypeVar, Union + +from typing_extensions import override from bumble import hci, utils from bumble.colors import color @@ -69,7 +61,12 @@ L2CAP_MIN_LE_MTU = 23 L2CAP_MIN_BR_EDR_MTU = 48 L2CAP_MAX_BR_EDR_MTU = 65535 -L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept +L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept +L2CAP_DEFAULT_MPS = 1010 # Default value for the MPS we are willing to accept +DEFAULT_TX_WINDOW_SIZE = 63 +DEFAULT_MAX_RETRANSMISSION = 1 +DEFAULT_RETRANSMISSION_TIMEOUT = 2.0 +DEFAULT_MONITOR_TIMEOUT = 12.0 L2CAP_DEFAULT_CONNECTIONLESS_MTU = 1024 @@ -133,24 +130,60 @@ L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256 -L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01 - -L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01 - # fmt: on # pylint: enable=line-too-long +class TransmissionMode(utils.OpenIntEnum): + '''See Bluetooth spec @ Vol 3, Part A - 5.4. Retransmission and Flow Control option''' + + BASIC = 0x00 + RETRANSMISSION = 0x01 + FLOW_CONTROL = 0x02 + ENHANCED_RETRANSMISSION = 0x03 + STREAMING = 0x04 + + # ----------------------------------------------------------------------------- # Classes # ----------------------------------------------------------------------------- # pylint: disable=invalid-name +class L2capError(ProtocolError): + def __init__(self, error_code, error_name='', details=''): + super().__init__(error_code, 'L2CAP', error_name, details) + + @dataclasses.dataclass class ClassicChannelSpec: + '''Spec of L2CAP Channel over Classic Transport. + + Attributes: + psm: PSM of channel. This is optional for server, and when it is None, a PSM + will be allocated. + mtu: Maximum Transmission Unit. + mps: Maximum PDU payload Size. + tx_window_size: The size of the transmission window for Flow Control mode, + Retransmission mode, and Enhanced Retransmission mode. + max_retransmission: The number of transmissions of a single I-frame that L2CAP + is allowed to try in Retransmission mode and Enhanced Retransmission mode. + retransmission_timeout: The timeout of retransmission in seconds. + monitor_timeout: The interval at which S-frames should be transmitted on the + return channel when no frames are received on the forward channel. + mode: The transmission mode to use. + fcs_enabled: Whether to enable FCS (Frame Check Sequence). + ''' + psm: Optional[int] = None mtu: int = L2CAP_DEFAULT_MTU + mps: int = L2CAP_DEFAULT_MPS + tx_window_size: int = DEFAULT_TX_WINDOW_SIZE + max_retransmission: int = DEFAULT_MAX_RETRANSMISSION + retransmission_timeout: float = DEFAULT_RETRANSMISSION_TIMEOUT + monitor_timeout: float = DEFAULT_MONITOR_TIMEOUT + mode: TransmissionMode = TransmissionMode.BASIC + fcs_enabled: bool = False @dataclasses.dataclass @@ -183,20 +216,29 @@ class L2CAP_PDU: See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT ''' - @staticmethod - def from_bytes(data: bytes) -> L2CAP_PDU: + @classmethod + def from_bytes(cls, data: bytes) -> L2CAP_PDU: # Check parameters if len(data) < 4: raise InvalidPacketError('not enough data for L2CAP header') - _, l2cap_pdu_cid = struct.unpack_from(' bytes: - header = struct.pack(' bytes: + length = len(self.payload) + if with_fcs: + length += 2 + header = struct.pack(' None: self.cid = cid @@ -206,6 +248,120 @@ class L2CAP_PDU: return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}' +class ControlField: + ''' + See Bluetooth spec @ Vol 3, Part A - 3.3.2 Control field. + ''' + + class FieldType(utils.OpenIntEnum): + I_FRAME = 0x00 + S_FRAME = 0x01 + + class SegmentationAndReassembly(utils.OpenIntEnum): + UNSEGMENTED = 0x00 + START = 0x01 + END = 0x02 + CONTINUATION = 0x03 + + class SupervisoryFunction(utils.OpenIntEnum): + # Receiver Ready + RR = 0 + # Reject + REJ = 1 + # Receiver Not Ready + RNR = 2 + # Select Reject + SREJ = 3 + + class RetransmissionBit(utils.OpenIntEnum): + NORMAL = 0x00 + RETRANSMISSION = 0x01 + + req_seq: int + frame_type: ClassVar[FieldType] + + def __bytes__(self) -> bytes: + raise NotImplementedError() + + +class EnhancedControlField(ControlField): + """Base control field used in Enhanced Retransmission and Streaming Mode.""" + + final: int + + @classmethod + def from_bytes(cls, data: bytes) -> EnhancedControlField: + frame_type = data[0] & 0x01 + if frame_type == cls.FieldType.I_FRAME: + return InformationEnhancedControlField.from_bytes(data) + elif frame_type == cls.FieldType.S_FRAME: + return SupervisoryEnhancedControlField.from_bytes(data) + else: + raise InvalidArgumentError(f'Invalid frame type: {frame_type}') + + +@dataclasses.dataclass +class InformationEnhancedControlField(EnhancedControlField): + tx_seq: int = 0 + req_seq: int = 0 + segmentation_and_reassembly: int = ( + EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED + ) + final: int = 1 + + frame_type = EnhancedControlField.FieldType.I_FRAME + + @classmethod + def from_bytes(cls, data: bytes) -> EnhancedControlField: + return cls( + tx_seq=(data[0] >> 1) & 0b0111111, + final=(data[0] >> 7) & 0b1, + req_seq=(data[1] & 0b001111111), + segmentation_and_reassembly=(data[1] >> 6) & 0b11, + ) + + def __bytes__(self) -> bytes: + return bytes( + [ + self.frame_type | (self.tx_seq << 1) | (self.final << 7), + self.req_seq | (self.segmentation_and_reassembly << 6), + ] + ) + + +@dataclasses.dataclass +class SupervisoryEnhancedControlField(EnhancedControlField): + + supervision_function: int = ControlField.SupervisoryFunction.RR + poll: int = 0 + req_seq: int = 0 + final: int = 0 + + frame_type = EnhancedControlField.FieldType.S_FRAME + + @classmethod + def from_bytes(cls, data: bytes) -> EnhancedControlField: + return cls( + supervision_function=(data[0] >> 2) & 0b11, + poll=(data[0] >> 4) & 0b1, + final=(data[0] >> 7) & 0b1, + req_seq=(data[1] & 0b1111111), + ) + + def __bytes__(self) -> bytes: + return bytes( + [ + ( + self.frame_type + | (self.supervision_function << 2) + | self.poll << 7 + | (self.final << 7) + ), + self.req_seq, + ] + ) + + # ----------------------------------------------------------------------------- @dataclasses.dataclass class L2CAP_Control_Frame: @@ -248,14 +404,16 @@ class L2CAP_Control_Frame: return frame @staticmethod - def decode_configuration_options(data: bytes) -> list[tuple[int, bytes]]: + def decode_configuration_options( + data: bytes, + ) -> list[tuple[L2CAP_Configure_Request.ParameterType, bytes]]: options = [] while len(data) >= 2: value_type = data[0] length = data[1] value = data[2 : 2 + length] data = data[2 + length :] - options.append((value_type, value)) + options.append((L2CAP_Configure_Request.ParameterType(value_type), value)) return options @@ -398,6 +556,15 @@ class L2CAP_Configure_Request(L2CAP_Control_Frame): See Bluetooth spec @ Vol 3, Part A - 4.4 CONFIGURATION REQUEST ''' + class ParameterType(utils.OpenIntEnum): + MTU = 0x01 + FLUSH_TIMEOUT = 0x02 + QOS = 0x03 + RETRANSMISSION_AND_FLOW_CONTROL = 0x04 + FCS = 0x05 + EXTENDED_FLOW_SPEC = 0x06 + EXTENDED_WINDOW_SIZE = 0x07 + destination_cid: int = dataclasses.field(metadata=hci.metadata(2)) flags: int = dataclasses.field(metadata=hci.metadata(2)) options: bytes = dataclasses.field(metadata=hci.metadata('*')) @@ -484,17 +651,18 @@ class L2CAP_Information_Request(L2CAP_Control_Frame): EXTENDED_FEATURES_SUPPORTED = 0x0002 FIXED_CHANNELS_SUPPORTED = 0x0003 - EXTENDED_FEATURE_FLOW_MODE_CONTROL = 0x0001 - EXTENDED_FEATURE_RETRANSMISSION_MODE = 0x0002 - EXTENDED_FEATURE_BIDIRECTIONAL_QOS = 0x0004 - EXTENDED_FEATURE_ENHANCED_RETRANSMISSION_MODE = 0x0008 - EXTENDED_FEATURE_STREAMING_MODE = 0x0010 - EXTENDED_FEATURE_FCS_OPTION = 0x0020 - EXTENDED_FEATURE_EXTENDED_FLOW_SPEC = 0x0040 - EXTENDED_FEATURE_FIXED_CHANNELS = 0x0080 - EXTENDED_FEATURE_EXTENDED_WINDOW_SIZE = 0x0100 - EXTENDED_FEATURE_UNICAST_CONNECTIONLESS_DATA = 0x0200 - EXTENDED_FEATURE_ENHANCED_CREDIT_BASE_FLOW_CONTROL = 0x0400 + class ExtendedFeatures(hci.SpecableFlag): + FLOW_MODE_CONTROL = 0x0001 + RETRANSMISSION_MODE = 0x0002 + BIDIRECTIONAL_QOS = 0x0004 + ENHANCED_RETRANSMISSION_MODE = 0x0008 + STREAMING_MODE = 0x0010 + FCS_OPTION = 0x0020 + EXTENDED_FLOW_SPEC = 0x0040 + FIXED_CHANNELS = 0x0080 + EXTENDED_WINDOW_SIZE = 0x0100 + UNICAST_CONNECTIONLESS_DATA = 0x0200 + ENHANCED_CREDIT_BASE_FLOW_CONTROL = 0x0400 info_type: int = dataclasses.field(metadata=InfoType.type_metadata(2)) @@ -702,6 +870,218 @@ class L2CAP_Credit_Based_Reconfigure_Response(L2CAP_Control_Frame): result: int = dataclasses.field(metadata=Result.type_metadata(2)) +# ----------------------------------------------------------------------------- +class Processor: + def __init__(self, channel: ClassicChannel) -> None: + self.channel = channel + + def send_sdu(self, sdu: bytes) -> None: + self.channel.send_pdu(sdu) + + def on_pdu(self, pdu: bytes) -> None: + self.channel.on_sdu(pdu) + + +# TODO: Handle retransmission +class EnhancedRetransmissionProcessor(Processor): + + MAX_SEQ_NUM = 64 + + @dataclasses.dataclass + class _PendingPdu: + payload: bytes + tx_seq: int + req_seq: int = 0 + + def __bytes__(self) -> bytes: + return ( + bytes( + InformationEnhancedControlField( + tx_seq=self.tx_seq, req_seq=self.req_seq + ) + ) + + self.payload + ) + + _expected_ack_seq: int = 0 + _next_tx_seq: int = 0 + _last_tx_seq: int = 0 + _req_seq_num: int = 0 + _next_seq_num: int = 0 + _remote_is_busy: bool = False + + _num_receiver_ready_polls_sent: int = 0 + _pending_pdus: list[_PendingPdu] + _monitor_handle: Optional[asyncio.TimerHandle] = None + _receiver_ready_poll_handle: Optional[asyncio.TimerHandle] = None + + # Timeout, in seconds. + monitor_timeout: float + retransmission_timeout: float + + @classmethod + def _num_frames_between(cls, low: int, high: int) -> int: + if high < low: + high += cls.MAX_SEQ_NUM + return high - low + + def __init__( + self, + channel: ClassicChannel, + peer_tx_window_size: int = DEFAULT_TX_WINDOW_SIZE, + peer_max_retransmission: int = DEFAULT_MAX_RETRANSMISSION, + peer_mps: int = L2CAP_DEFAULT_MPS, + ): + spec = channel.spec + self.mps = spec.mps + self.peer_mps = peer_mps + self.peer_tx_window_size = peer_tx_window_size + self._pending_pdus = [] + self.monitor_timeout = spec.monitor_timeout + self.channel = channel + self.retransmission_timeout = spec.retransmission_timeout + self.peer_max_retransmission = peer_max_retransmission + + def _monitor(self) -> None: + if ( + self.peer_max_retransmission <= 0 + or self._num_receiver_ready_polls_sent < self.peer_max_retransmission + ): + self._send_receiver_ready_poll() + self._start_monitor() + else: + logger.error("Max retransmission exceeded") + + def _receiver_ready_poll(self) -> None: + self._send_receiver_ready_poll() + self._start_monitor() + + def _start_monitor(self) -> None: + if self._monitor_handle: + self._monitor_handle.cancel() + self._monitor_handle = asyncio.get_running_loop().call_later( + self.monitor_timeout, self._monitor + ) + + def _start_receiver_ready_poll(self) -> None: + if self._receiver_ready_poll_handle: + self._receiver_ready_poll_handle.cancel() + self._num_receiver_ready_polls_sent = 0 + + self._receiver_ready_poll_handle = asyncio.get_running_loop().call_later( + self.retransmission_timeout, self._receiver_ready_poll + ) + + def _send_receiver_ready_poll(self) -> None: + self._num_receiver_ready_polls_sent += 1 + self.channel.send_pdu( + SupervisoryEnhancedControlField( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + final=1, + req_seq=self._next_seq_num, + ) + ) + + def _get_next_tx_seq(self) -> int: + seq_num = self._next_tx_seq + self._next_tx_seq = (self._next_tx_seq + 1) % self.MAX_SEQ_NUM + return seq_num + + @override + def send_sdu(self, sdu: bytes) -> None: + if len(sdu) > self.peer_mps: + raise InvalidArgumentError( + f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}' + ) + pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq()) + self._pending_pdus.append(pdu) + self._process_output() + + @override + def on_pdu(self, pdu: bytes) -> None: + control_field = EnhancedControlField.from_bytes(pdu) + self._update_ack_seq(control_field.req_seq, control_field.final != 0) + if isinstance(control_field, InformationEnhancedControlField): + if control_field.tx_seq != self._next_seq_num: + return + self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM + self._req_seq_num = self._next_seq_num + + ack_frame = SupervisoryEnhancedControlField( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + req_seq=self._next_seq_num, + ) + self.channel.send_pdu(ack_frame) + self.channel.on_sdu(pdu[2:]) + elif isinstance(control_field, SupervisoryEnhancedControlField): + self._remote_is_busy = ( + control_field.supervision_function + == SupervisoryEnhancedControlField.SupervisoryFunction.RNR + ) + + if control_field.supervision_function in ( + SupervisoryEnhancedControlField.SupervisoryFunction.RR, + SupervisoryEnhancedControlField.SupervisoryFunction.RNR, + ): + if control_field.poll: + self.channel.send_pdu( + SupervisoryEnhancedControlField( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + final=1, + req_seq=self._next_seq_num, + ) + ) + else: + # TODO: Handle Retransmission. + pass + + def _process_output(self) -> None: + if self._remote_is_busy or self._monitor_handle: + return + + for pdu in self._pending_pdus: + if self._num_unacked_frames >= self.peer_tx_window_size: + return + self._send_pdu(pdu) + self._last_tx_seq = pdu.tx_seq + + @property + def _num_unacked_frames(self) -> int: + if not self._pending_pdus: + return 0 + return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1) + + def _send_pdu(self, pdu: _PendingPdu) -> None: + pdu.req_seq = self._req_seq_num + + self._start_receiver_ready_poll() + self.channel.send_pdu(bytes(pdu)) + + def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None: + num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq) + if num_frames_acked > self._num_unacked_frames: + logger.error( + "Received acknowledgment for %d frames but only %d frames are pending", + num_frames_acked, + self._num_unacked_frames, + ) + return + if is_poll_response and self._monitor_handle: + self._monitor_handle.cancel() + self._monitor_handle = None + + del self._pending_pdus[:num_frames_acked] + self._expected_ack_seq = new_seq + if ( + self._expected_ack_seq == self._next_tx_seq + and self._receiver_ready_poll_handle + ): + self._receiver_ready_poll_handle.cancel() + self._receiver_ready_poll_handle = None + + self._process_output() + + # ----------------------------------------------------------------------------- class ClassicChannel(utils.EventEmitter): class State(enum.IntEnum): @@ -739,6 +1119,7 @@ class ClassicChannel(utils.EventEmitter): connection: Connection mtu: int peer_mtu: int + processor: Processor def __init__( self, @@ -747,14 +1128,14 @@ class ClassicChannel(utils.EventEmitter): signaling_cid: int, psm: int, source_cid: int, - mtu: int, + spec: ClassicChannelSpec, ) -> None: super().__init__() self.manager = manager self.connection = connection self.signaling_cid = signaling_cid self.state = self.State.CLOSED - self.mtu = mtu + self.mtu = spec.mtu self.peer_mtu = L2CAP_MIN_BR_EDR_MTU self.psm = psm self.source_cid = source_cid @@ -762,26 +1143,47 @@ class ClassicChannel(utils.EventEmitter): self.connection_result = None self.disconnection_result = None self.sink = None + self.fcs_enabled = spec.fcs_enabled + self.spec = spec + self.mode = spec.mode + # Configure mode-specific processor later on configure request. + self.processor = Processor(self) + if self.mode not in ( + TransmissionMode.BASIC, + TransmissionMode.ENHANCED_RETRANSMISSION, + ): + raise InvalidArgumentError(f"Mode {spec.mode} is not supported") def _change_state(self, new_state: State) -> None: logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') self.state = new_state + def write(self, sdu: bytes) -> None: + self.processor.send_sdu(sdu) + def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: if self.state != self.State.OPEN: raise InvalidStateError('channel not open') - self.manager.send_pdu(self.connection, self.destination_cid, pdu) + self.manager.send_pdu( + self.connection, self.destination_cid, pdu, self.fcs_enabled + ) def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: self.manager.send_control_frame(self.connection, self.signaling_cid, frame) def on_pdu(self, pdu: bytes) -> None: + if self.fcs_enabled: + # Drop FCS. + pdu = pdu[:-2] + self.processor.on_pdu(pdu) + + def on_sdu(self, sdu: bytes) -> None: if self.sink: # pylint: disable=not-callable - self.sink(pdu) + self.sink(sdu) else: logger.warning( - color('received pdu without a pending request or sink', 'red') + color('received sdu without a pending request or sink', 'red') ) async def connect(self) -> None: @@ -811,10 +1213,8 @@ class ClassicChannel(utils.EventEmitter): finally: self.connection_result = None - async def disconnect(self) -> None: - if self.state != self.State.OPEN: - raise InvalidStateError('invalid state') - + def _disconnect_sync(self) -> None: + """For internal sync disconnection.""" self._change_state(self.State.WAIT_DISCONNECT) self.send_control_frame( L2CAP_Disconnection_Request( @@ -827,7 +1227,21 @@ class ClassicChannel(utils.EventEmitter): # Create a future to wait for the state machine to get to a success or error # state self.disconnection_result = asyncio.get_running_loop().create_future() - return await self.disconnection_result + + def _abort_connection_result(self, message: str = 'Connection failure') -> None: + # Cancel pending connection result. + if self.connection_result and not self.connection_result.done(): + self.connection_result.set_exception( + L2capError(error_code=0, error_name=message) + ) + + async def disconnect(self) -> None: + if self.state != self.State.OPEN: + raise InvalidStateError('invalid state') + + self._disconnect_sync() + if self.disconnection_result: + return await self.disconnection_result def abort(self) -> None: if self.state == self.State.OPEN: @@ -835,20 +1249,40 @@ class ClassicChannel(utils.EventEmitter): self.emit(self.EVENT_CLOSE) def send_configure_request(self) -> None: - options = L2CAP_Control_Frame.encode_configuration_options( - [ + options: list[tuple[int, bytes]] = [ + ( + L2CAP_Configure_Request.ParameterType.MTU, + struct.pack(' None: - if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT): - self.send_control_frame( - L2CAP_Disconnection_Response( - identifier=request.identifier, - destination_cid=request.destination_cid, - source_cid=request.source_cid, - ) + self.send_control_frame( + L2CAP_Disconnection_Response( + identifier=request.identifier, + destination_cid=request.destination_cid, + source_cid=request.source_cid, ) - self._change_state(self.State.CLOSED) - self.emit(self.EVENT_CLOSE) - self.manager.on_channel_closed(self) - else: - logger.warning(color('invalid state', 'red')) + ) + self._abort_connection_result() + self._change_state(self.State.CLOSED) + self.emit(self.EVENT_CLOSE) + self.manager.on_channel_closed(self) def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None: - if self.state != self.State.WAIT_DISCONNECT: - logger.warning(color('invalid state', 'red')) - return - if ( response.destination_cid != self.destination_cid or response.source_cid != self.source_cid @@ -1242,9 +1760,8 @@ class LeCreditBasedChannel(utils.EventEmitter): self._change_state(self.State.CONNECTED) else: self.connection_result.set_exception( - ProtocolError( + L2capError( response.result, - 'l2cap', L2CAP_LE_Credit_Based_Connection_Response.Result( response.result ).name, @@ -1383,13 +1900,13 @@ class ClassicChannelServer(utils.EventEmitter): manager: ChannelManager, psm: int, handler: Optional[Callable[[ClassicChannel], Any]], - mtu: int, + spec: ClassicChannelSpec, ) -> None: super().__init__() self.manager = manager self.handler = handler self.psm = psm - self.mtu = mtu + self.spec = spec def on_connection(self, channel: ClassicChannel) -> None: self.emit(self.EVENT_CONNECTION, channel) @@ -1462,7 +1979,7 @@ class ChannelManager: ) # LE CoC channels, mapped by connection and destination cid self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM self.le_coc_requests = {} # LE CoC connection requests, by identifier - self.extended_features = extended_features + self.extended_features = set(extended_features) self.connectionless_mtu = connectionless_mtu self.connection_parameters_update_response = None @@ -1566,7 +2083,7 @@ class ChannelManager: raise InvalidArgumentError('invalid PSM') check >>= 8 - self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) + self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec) return self.servers[spec.psm] @@ -1615,7 +2132,13 @@ class ChannelManager: if connection_handle in self.identifiers: del self.identifiers[connection_handle] - def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None: + def send_pdu( + self, + connection: Connection, + cid: int, + pdu: Union[SupportsBytes, bytes], + with_fcs: bool = False, + ) -> None: pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_bytes = bytes(pdu) logger.debug( @@ -1623,7 +2146,9 @@ class ChannelManager: f'on connection [0x{connection.handle:04X}] (CID={cid}) ' f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}' ) - self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes) + self.host.send_acl_sdu( + connection.handle, L2CAP_PDU(cid, bytes(pdu)).to_bytes(with_fcs=with_fcs) + ) def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): @@ -1729,7 +2254,7 @@ class ChannelManager: f'creating server channel with cid={source_cid} for psm {request.psm}' ) channel = ClassicChannel( - self, connection, cid, request.psm, source_cid, server.mtu + self, connection, cid, request.psm, source_cid, server.spec ) connection_channels[source_cid] = channel @@ -2187,12 +2712,12 @@ class ChannelManager: f'creating client channel with cid={source_cid} for psm {spec.psm}' ) channel = ClassicChannel( - self, - connection, - L2CAP_SIGNALING_CID, - spec.psm, - source_cid, - spec.mtu, + manager=self, + connection=connection, + signaling_cid=L2CAP_SIGNALING_CID, + psm=spec.psm, + source_cid=source_cid, + spec=spec, ) connection_channels[source_cid] = channel @@ -2200,7 +2725,27 @@ class ChannelManager: try: await channel.connect() except BaseException as e: - del connection_channels[source_cid] + connection_channels.pop(source_cid, None) raise e return channel + + @classmethod + def make_mode_processor( + self, + channel: ClassicChannel, + mode: TransmissionMode, + peer_tx_window_size: int, + peer_max_retransmission: int, + peer_retransmission_timeout: int, + peer_monitor_timeout: int, + peer_mps: int, + ) -> Processor: + del peer_retransmission_timeout, peer_monitor_timeout # Unused. + if mode == TransmissionMode.BASIC: + return Processor(channel) + elif mode == TransmissionMode.ENHANCED_RETRANSMISSION: + return EnhancedRetransmissionProcessor( + channel, peer_tx_window_size, peer_max_retransmission, peer_mps + ) + raise InvalidArgumentError("Mode %s is not implemented", mode.name) diff --git a/bumble/utils.py b/bumble/utils.py index fcb7429..17bb6c3 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -533,3 +533,20 @@ class IntConvertible(Protocol): def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... + + +# ----------------------------------------------------------------------------- +def crc_16(data: bytes) -> int: + """Calculate CRC-16-IBM of given data. + + Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed) + """ + crc = 0x0000 + for byte in data: + crc ^= byte + for _ in range(8): + if (crc & 0x0001) > 0: + crc = (crc >> 1) ^ 0xA001 + else: + crc = crc >> 1 + return crc diff --git a/examples/run_classic_l2cap.py b/examples/run_classic_l2cap.py new file mode 100644 index 0000000..b780065 --- /dev/null +++ b/examples/run_classic_l2cap.py @@ -0,0 +1,107 @@ +# Copyright 2021-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import argparse +import asyncio +import sys + +import bumble.logging +from bumble import core, l2cap +from bumble.device import Device +from bumble.transport import open_transport + + +# ----------------------------------------------------------------------------- +async def main( + config_file: str, transport: str, mode: int, peer_address: str, psm: int +) -> None: + + print('<<< connecting to HCI...') + async with await open_transport(transport) as hci_transport: + print('<<< connected') + + # Create a device + device = Device.from_config_file_with_hci( + config_file, hci_transport.source, hci_transport.sink + ) + device.classic_enabled = True + device.l2cap_channel_manager.extended_features.add( + l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE + ) + device.l2cap_channel_manager.extended_features.add( + l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION + ) + + # Start the controller + await device.power_on() + + # Start being discoverable and connectable + await device.set_discoverable(True) + await device.set_connectable(True) + + active_channel: l2cap.ClassicChannel | None = None + + def on_connection(channel: l2cap.ClassicChannel): + + def on_sdu(sdu: bytes): + print(f'<<< {sdu.decode()}') + + channel.sink = on_sdu + nonlocal active_channel + active_channel = channel + + server = device.create_l2cap_server( + spec=l2cap.ClassicChannelSpec( + mode=l2cap.TransmissionMode(mode), psm=psm if psm else None + ), + handler=on_connection, + ) + print(f'Listen L2CAP on channel {server.psm}') + + if peer_address: + connection = await device.connect( + peer_address, transport=core.PhysicalTransport.BR_EDR + ) + channel = await connection.create_l2cap_channel( + spec=l2cap.ClassicChannelSpec( + mode=l2cap.TransmissionMode(mode), psm=psm + ) + ) + active_channel = channel + + while sdu := await asyncio.to_thread(lambda: input('>>> ')): + if active_channel: + active_channel.write(sdu.encode()) + + await hci_transport.source.terminated + + +# ----------------------------------------------------------------------------- +bumble.logging.setup_basic_logging('INFO') +parser = argparse.ArgumentParser() +parser.add_argument('config') +parser.add_argument('transport') +parser.add_argument('-p', '--peer_address', default='') +parser.add_argument( + '-m', '--mode', default=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION +) +parser.add_argument('--psm', default=0) +args = parser.parse_args(sys.argv[1:]) +asyncio.run(main(args.config, args.transport, args.mode, args.peer_address, args.psm)) diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index cbeb874..bfc566e 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -19,6 +19,7 @@ import asyncio import logging import os import random +import struct import pytest @@ -342,6 +343,76 @@ async def test_mtu(): assert client_channel.peer_mtu == 345 +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_enhanced_retransmission_mode(): + devices = TwoDevices() + await devices.setup_connection() + + server_channels = asyncio.Queue[l2cap.ClassicChannel]() + server = devices.devices[1].create_l2cap_server( + spec=l2cap.ClassicChannelSpec( + mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION + ), + handler=server_channels.put_nowait, + ) + client_channel = await devices.connections[0].create_l2cap_channel( + spec=l2cap.ClassicChannelSpec( + server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION + ) + ) + server_channel = await server_channels.get() + + sinks = [asyncio.Queue[bytes]() for _ in range(2)] + server_channel.sink = sinks[0].put_nowait + client_channel.sink = sinks[1].put_nowait + + for i in range(128): + server_channel.write(struct.pack('