diff --git a/apps/bench.py b/apps/bench.py index 1f9d45f7..83625f00 100644 --- a/apps/bench.py +++ b/apps/bench.py @@ -87,6 +87,7 @@ DEFAULT_LINGER_TIME = 1.0 DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0 DEFAULT_RFCOMM_CHANNEL = 8 +DEFAULT_RFCOMM_MTU = 2048 # ----------------------------------------------------------------------------- @@ -896,11 +897,14 @@ class L2capServer(StreamedPacketIO): # RfcommClient # ----------------------------------------------------------------------------- class RfcommClient(StreamedPacketIO): - def __init__(self, device, channel, uuid): + def __init__(self, device, channel, uuid, l2cap_mtu, max_frame_size, window_size): super().__init__() self.device = device self.channel = channel self.uuid = uuid + self.l2cap_mtu = l2cap_mtu + self.max_frame_size = max_frame_size + self.window_size = window_size self.rfcomm_session = None self.ready = asyncio.Event() @@ -924,13 +928,21 @@ class RfcommClient(StreamedPacketIO): # Create a client and start it logging.info(color('*** Starting RFCOMM client...', 'blue')) - rfcomm_client = bumble.rfcomm.Client(connection) + rfcomm_options = {} + if self.l2cap_mtu: + rfcomm_options['l2cap_mtu'] = self.l2cap_mtu + rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options) rfcomm_mux = await rfcomm_client.start() logging.info(color('*** Started', 'blue')) logging.info(color(f'### Opening session for channel {channel}...', 'yellow')) try: - rfcomm_session = await rfcomm_mux.open_dlc(channel) + dlc_options = {} + if self.max_frame_size: + dlc_options['max_frame_size'] = self.max_frame_size + if self.window_size: + dlc_options['window_size'] = self.window_size + rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options) logging.info(color(f'### Session open: {rfcomm_session}', 'yellow')) except bumble.core.ConnectionError as error: logging.info(color(f'!!! Session open failed: {error}', 'red')) @@ -955,13 +967,16 @@ class RfcommClient(StreamedPacketIO): # RfcommServer # ----------------------------------------------------------------------------- class RfcommServer(StreamedPacketIO): - def __init__(self, device, channel): + def __init__(self, device, channel, l2cap_mtu): super().__init__() self.dlc = None self.ready = asyncio.Event() # Create and register a server - rfcomm_server = bumble.rfcomm.Server(device) + server_options = {} + if l2cap_mtu: + server_options['l2cap_mtu'] = l2cap_mtu + rfcomm_server = bumble.rfcomm.Server(device, **server_options) # Listen for incoming DLC connections channel_number = rfcomm_server.listen(self.on_dlc, channel) @@ -1298,11 +1313,20 @@ def create_mode_factory(ctx, default_mode): if mode == 'rfcomm-client': return RfcommClient( - device, channel=ctx.obj['rfcomm_channel'], uuid=ctx.obj['rfcomm_uuid'] + device, + channel=ctx.obj['rfcomm_channel'], + uuid=ctx.obj['rfcomm_uuid'], + l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], + max_frame_size=ctx.obj['rfcomm_max_frame_size'], + window_size=ctx.obj['rfcomm_window_size'], ) if mode == 'rfcomm-server': - return RfcommServer(device, channel=ctx.obj['rfcomm_channel']) + return RfcommServer( + device, + channel=ctx.obj['rfcomm_channel'], + l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], + ) raise ValueError('invalid mode') @@ -1389,6 +1413,21 @@ def create_role_factory(ctx, default_role): default=DEFAULT_RFCOMM_UUID, help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)', ) +@click.option( + '--rfcomm-l2cap-mtu', + type=int, + help='RFComm L2CAP MTU', +) +@click.option( + '--rfcomm-max-frame-size', + type=int, + help='RFComm maximum frame size', +) +@click.option( + '--rfcomm-window-size', + type=int, + help='RFComm window size', +) @click.option( '--l2cap-psm', type=int, @@ -1486,6 +1525,9 @@ def bench( linger, rfcomm_channel, rfcomm_uuid, + rfcomm_l2cap_mtu, + rfcomm_max_frame_size, + rfcomm_window_size, l2cap_psm, l2cap_mtu, l2cap_mps, @@ -1498,6 +1540,9 @@ def bench( ctx.obj['att_mtu'] = att_mtu ctx.obj['rfcomm_channel'] = rfcomm_channel ctx.obj['rfcomm_uuid'] = rfcomm_uuid + ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu + ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size + ctx.obj['rfcomm_window_size'] = rfcomm_window_size ctx.obj['l2cap_psm'] = l2cap_psm ctx.obj['l2cap_mtu'] = l2cap_mtu ctx.obj['l2cap_mps'] = l2cap_mps diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 3be1e157..f7851099 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -1470,10 +1470,10 @@ class Protocol(EventEmitter): f'[{transaction_label}] {message}' ) max_fragment_size = ( - self.l2cap_channel.mtu - 3 + self.l2cap_channel.peer_mtu - 3 ) # Enough space for a 3-byte start packet header payload = message.payload - if len(payload) + 2 <= self.l2cap_channel.mtu: + if len(payload) + 2 <= self.l2cap_channel.peer_mtu: # Fits in a single packet packet_type = self.PacketType.SINGLE_PACKET else: diff --git a/bumble/hid.py b/bumble/hid.py index 5ea9b98a..fc5c8074 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -416,7 +416,7 @@ class Device(HID): data = bytearray() data.append(report_id) data.extend(ret.data) - if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr] + if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr] self.send_control_data(report_type=report_type, data=data) else: self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index f91a269f..cec14b85 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -173,7 +173,7 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01 @dataclasses.dataclass class ClassicChannelSpec: psm: Optional[int] = None - mtu: int = L2CAP_MIN_BR_EDR_MTU + mtu: int = L2CAP_DEFAULT_MTU @dataclasses.dataclass @@ -749,6 +749,8 @@ class ClassicChannel(EventEmitter): sink: Optional[Callable[[bytes], Any]] state: State connection: Connection + mtu: int + peer_mtu: int def __init__( self, @@ -765,6 +767,7 @@ class ClassicChannel(EventEmitter): self.signaling_cid = signaling_cid self.state = self.State.CLOSED self.mtu = mtu + self.peer_mtu = L2CAP_MIN_BR_EDR_MTU self.psm = psm self.source_cid = source_cid self.destination_cid = 0 @@ -861,7 +864,7 @@ class ClassicChannel(EventEmitter): [ ( L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE, - struct.pack('{self.destination_cid}, ' f'PSM={self.psm}, ' - f'MTU={self.mtu}, ' + f'MTU={self.mtu}/{self.peer_mtu}, ' f'state={self.state.name})' ) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 5500bc12..6ca0f509 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -104,6 +104,7 @@ CRC_TABLE = bytes([ 0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF ]) +RFCOMM_DEFAULT_L2CAP_MTU = 2048 RFCOMM_DEFAULT_WINDOW_SIZE = 7 RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000 @@ -473,7 +474,7 @@ class DLC(EventEmitter): # Compute the MTU max_overhead = 4 + 1 # header with 2-byte length + fcs self.mtu = min( - max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead + max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead ) def change_state(self, new_state: State) -> None: @@ -908,8 +909,11 @@ class Client: multiplexer: Optional[Multiplexer] l2cap_channel: Optional[l2cap.ClassicChannel] - def __init__(self, connection: Connection) -> None: + def __init__( + self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU + ) -> None: self.connection = connection + self.l2cap_mtu = l2cap_mtu self.l2cap_channel = None self.multiplexer = None @@ -917,7 +921,7 @@ class Client: # Create a new L2CAP connection try: self.l2cap_channel = await self.connection.create_l2cap_channel( - spec=l2cap.ClassicChannelSpec(RFCOMM_PSM) + spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu) ) except ProtocolError as error: logger.warning(f'L2CAP connection failed: {error}') @@ -955,7 +959,9 @@ class Client: class Server(EventEmitter): acceptors: Dict[int, Callable[[DLC], None]] - def __init__(self, device: Device) -> None: + def __init__( + self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU + ) -> None: super().__init__() self.device = device self.multiplexer = None @@ -963,7 +969,8 @@ class Server(EventEmitter): # Register ourselves with the L2CAP channel manager self.l2cap_server = device.create_l2cap_server( - spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection + spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu), + handler=self.on_connection, ) def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int: diff --git a/examples/run_a2dp_source.py b/examples/run_a2dp_source.py index 92812fe1..46452293 100644 --- a/examples/run_a2dp_source.py +++ b/examples/run_a2dp_source.py @@ -74,7 +74,7 @@ def codec_capabilities(): # ----------------------------------------------------------------------------- def on_avdtp_connection(read_function, protocol): packet_source = SbcPacketSource( - read_function, protocol.l2cap_channel.mtu, codec_capabilities() + read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() ) packet_pump = MediaPacketPump(packet_source.packets) protocol.add_source(packet_source.codec_capabilities, packet_pump) @@ -98,7 +98,7 @@ async def stream_packets(read_function, protocol): # Stream the packets packet_source = SbcPacketSource( - read_function, protocol.l2cap_channel.mtu, codec_capabilities() + read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() ) packet_pump = MediaPacketPump(packet_source.packets) source = protocol.add_source(packet_source.codec_capabilities, packet_pump) diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 5cb285c3..6323ddfa 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -227,12 +227,34 @@ async def test_bidirectional_transfer(): assert server_received_bytes == message_bytes +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_mtu(): + devices = TwoDevices() + await devices.setup_connection() + + def on_channel_open(channel): + assert channel.peer_mtu == 456 + + def on_channel(channel): + channel.on('open', lambda: on_channel_open(channel)) + + server = devices.devices[1].create_l2cap_server( + spec=ClassicChannelSpec(mtu=345), handler=on_channel + ) + client_channel = await devices.connections[0].create_l2cap_channel( + spec=ClassicChannelSpec(server.psm, mtu=456) + ) + assert client_channel.peer_mtu == 345 + + # ----------------------------------------------------------------------------- async def run(): test_helpers() await test_basic_connection() await test_transfer() await test_bidirectional_transfer() + await test_mtu() # -----------------------------------------------------------------------------