add rfcomm options and fix l2cap mtu negotiation

This commit is contained in:
Gilles Boccon-Gibod
2024-02-02 15:17:27 -08:00
parent 6d91e7e79b
commit a877283360
7 changed files with 99 additions and 22 deletions

View File

@@ -87,6 +87,7 @@ DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0 DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
DEFAULT_RFCOMM_CHANNEL = 8 DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -896,11 +897,14 @@ class L2capServer(StreamedPacketIO):
# RfcommClient # RfcommClient
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO): class RfcommClient(StreamedPacketIO):
def __init__(self, device, channel, uuid): def __init__(self, device, channel, uuid, l2cap_mtu, max_frame_size, window_size):
super().__init__() super().__init__()
self.device = device self.device = device
self.channel = channel self.channel = channel
self.uuid = uuid self.uuid = uuid
self.l2cap_mtu = l2cap_mtu
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rfcomm_session = None self.rfcomm_session = None
self.ready = asyncio.Event() self.ready = asyncio.Event()
@@ -924,13 +928,21 @@ class RfcommClient(StreamedPacketIO):
# Create a client and start it # Create a client and start it
logging.info(color('*** Starting RFCOMM client...', 'blue')) logging.info(color('*** Starting RFCOMM client...', 'blue'))
rfcomm_client = bumble.rfcomm.Client(connection) rfcomm_options = {}
if self.l2cap_mtu:
rfcomm_options['l2cap_mtu'] = self.l2cap_mtu
rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options)
rfcomm_mux = await rfcomm_client.start() rfcomm_mux = await rfcomm_client.start()
logging.info(color('*** Started', 'blue')) logging.info(color('*** Started', 'blue'))
logging.info(color(f'### Opening session for channel {channel}...', 'yellow')) logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
try: try:
rfcomm_session = await rfcomm_mux.open_dlc(channel) dlc_options = {}
if self.max_frame_size:
dlc_options['max_frame_size'] = self.max_frame_size
if self.window_size:
dlc_options['window_size'] = self.window_size
rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options)
logging.info(color(f'### Session open: {rfcomm_session}', 'yellow')) logging.info(color(f'### Session open: {rfcomm_session}', 'yellow'))
except bumble.core.ConnectionError as error: except bumble.core.ConnectionError as error:
logging.info(color(f'!!! Session open failed: {error}', 'red')) logging.info(color(f'!!! Session open failed: {error}', 'red'))
@@ -955,13 +967,16 @@ class RfcommClient(StreamedPacketIO):
# RfcommServer # RfcommServer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO): class RfcommServer(StreamedPacketIO):
def __init__(self, device, channel): def __init__(self, device, channel, l2cap_mtu):
super().__init__() super().__init__()
self.dlc = None self.dlc = None
self.ready = asyncio.Event() self.ready = asyncio.Event()
# Create and register a server # Create and register a server
rfcomm_server = bumble.rfcomm.Server(device) server_options = {}
if l2cap_mtu:
server_options['l2cap_mtu'] = l2cap_mtu
rfcomm_server = bumble.rfcomm.Server(device, **server_options)
# Listen for incoming DLC connections # Listen for incoming DLC connections
channel_number = rfcomm_server.listen(self.on_dlc, channel) channel_number = rfcomm_server.listen(self.on_dlc, channel)
@@ -1298,11 +1313,20 @@ def create_mode_factory(ctx, default_mode):
if mode == 'rfcomm-client': if mode == 'rfcomm-client':
return RfcommClient( return RfcommClient(
device, channel=ctx.obj['rfcomm_channel'], uuid=ctx.obj['rfcomm_uuid'] device,
channel=ctx.obj['rfcomm_channel'],
uuid=ctx.obj['rfcomm_uuid'],
l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
max_frame_size=ctx.obj['rfcomm_max_frame_size'],
window_size=ctx.obj['rfcomm_window_size'],
) )
if mode == 'rfcomm-server': if mode == 'rfcomm-server':
return RfcommServer(device, channel=ctx.obj['rfcomm_channel']) return RfcommServer(
device,
channel=ctx.obj['rfcomm_channel'],
l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'],
)
raise ValueError('invalid mode') raise ValueError('invalid mode')
@@ -1389,6 +1413,21 @@ def create_role_factory(ctx, default_role):
default=DEFAULT_RFCOMM_UUID, default=DEFAULT_RFCOMM_UUID,
help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)', help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
) )
@click.option(
'--rfcomm-l2cap-mtu',
type=int,
help='RFComm L2CAP MTU',
)
@click.option(
'--rfcomm-max-frame-size',
type=int,
help='RFComm maximum frame size',
)
@click.option(
'--rfcomm-window-size',
type=int,
help='RFComm window size',
)
@click.option( @click.option(
'--l2cap-psm', '--l2cap-psm',
type=int, type=int,
@@ -1486,6 +1525,9 @@ def bench(
linger, linger,
rfcomm_channel, rfcomm_channel,
rfcomm_uuid, rfcomm_uuid,
rfcomm_l2cap_mtu,
rfcomm_max_frame_size,
rfcomm_window_size,
l2cap_psm, l2cap_psm,
l2cap_mtu, l2cap_mtu,
l2cap_mps, l2cap_mps,
@@ -1498,6 +1540,9 @@ def bench(
ctx.obj['att_mtu'] = att_mtu ctx.obj['att_mtu'] = att_mtu
ctx.obj['rfcomm_channel'] = rfcomm_channel ctx.obj['rfcomm_channel'] = rfcomm_channel
ctx.obj['rfcomm_uuid'] = rfcomm_uuid ctx.obj['rfcomm_uuid'] = rfcomm_uuid
ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
ctx.obj['rfcomm_window_size'] = rfcomm_window_size
ctx.obj['l2cap_psm'] = l2cap_psm ctx.obj['l2cap_psm'] = l2cap_psm
ctx.obj['l2cap_mtu'] = l2cap_mtu ctx.obj['l2cap_mtu'] = l2cap_mtu
ctx.obj['l2cap_mps'] = l2cap_mps ctx.obj['l2cap_mps'] = l2cap_mps

View File

@@ -1470,10 +1470,10 @@ class Protocol(EventEmitter):
f'[{transaction_label}] {message}' f'[{transaction_label}] {message}'
) )
max_fragment_size = ( max_fragment_size = (
self.l2cap_channel.mtu - 3 self.l2cap_channel.peer_mtu - 3
) # Enough space for a 3-byte start packet header ) # Enough space for a 3-byte start packet header
payload = message.payload payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu: if len(payload) + 2 <= self.l2cap_channel.peer_mtu:
# Fits in a single packet # Fits in a single packet
packet_type = self.PacketType.SINGLE_PACKET packet_type = self.PacketType.SINGLE_PACKET
else: else:

View File

@@ -416,7 +416,7 @@ class Device(HID):
data = bytearray() data = bytearray()
data.append(report_id) data.append(report_id)
data.extend(ret.data) data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr] if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data) self.send_control_data(report_type=report_type, data=data)
else: else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)

View File

@@ -173,7 +173,7 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
@dataclasses.dataclass @dataclasses.dataclass
class ClassicChannelSpec: class ClassicChannelSpec:
psm: Optional[int] = None psm: Optional[int] = None
mtu: int = L2CAP_MIN_BR_EDR_MTU mtu: int = L2CAP_DEFAULT_MTU
@dataclasses.dataclass @dataclasses.dataclass
@@ -749,6 +749,8 @@ class ClassicChannel(EventEmitter):
sink: Optional[Callable[[bytes], Any]] sink: Optional[Callable[[bytes], Any]]
state: State state: State
connection: Connection connection: Connection
mtu: int
peer_mtu: int
def __init__( def __init__(
self, self,
@@ -765,6 +767,7 @@ class ClassicChannel(EventEmitter):
self.signaling_cid = signaling_cid self.signaling_cid = signaling_cid
self.state = self.State.CLOSED self.state = self.State.CLOSED
self.mtu = mtu self.mtu = mtu
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = 0 self.destination_cid = 0
@@ -861,7 +864,7 @@ class ClassicChannel(EventEmitter):
[ [
( (
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE, L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
struct.pack('<H', L2CAP_DEFAULT_MTU), struct.pack('<H', self.mtu),
) )
] ]
) )
@@ -926,8 +929,8 @@ class ClassicChannel(EventEmitter):
options = L2CAP_Control_Frame.decode_configuration_options(request.options) options = L2CAP_Control_Frame.decode_configuration_options(request.options)
for option in options: for option in options:
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE: if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
self.mtu = struct.unpack('<H', option[1])[0] self.peer_mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'MTU = {self.mtu}') logger.debug(f'peer MTU = {self.peer_mtu}')
self.send_control_frame( self.send_control_frame(
L2CAP_Configure_Response( L2CAP_Configure_Response(
@@ -1026,7 +1029,7 @@ class ClassicChannel(EventEmitter):
return ( return (
f'Channel({self.source_cid}->{self.destination_cid}, ' f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, ' f'PSM={self.psm}, '
f'MTU={self.mtu}, ' f'MTU={self.mtu}/{self.peer_mtu}, '
f'state={self.state.name})' f'state={self.state.name})'
) )

View File

@@ -104,6 +104,7 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF 0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
]) ])
RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_WINDOW_SIZE = 7 RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000 RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
@@ -473,7 +474,7 @@ class DLC(EventEmitter):
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min( self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
) )
def change_state(self, new_state: State) -> None: def change_state(self, new_state: State) -> None:
@@ -908,8 +909,11 @@ class Client:
multiplexer: Optional[Multiplexer] multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel] l2cap_channel: Optional[l2cap.ClassicChannel]
def __init__(self, connection: Connection) -> None: def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
self.connection = connection self.connection = connection
self.l2cap_mtu = l2cap_mtu
self.l2cap_channel = None self.l2cap_channel = None
self.multiplexer = None self.multiplexer = None
@@ -917,7 +921,7 @@ class Client:
# Create a new L2CAP connection # Create a new L2CAP connection
try: try:
self.l2cap_channel = await self.connection.create_l2cap_channel( self.l2cap_channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(RFCOMM_PSM) spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
) )
except ProtocolError as error: except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}') logger.warning(f'L2CAP connection failed: {error}')
@@ -955,7 +959,9 @@ class Client:
class Server(EventEmitter): class Server(EventEmitter):
acceptors: Dict[int, Callable[[DLC], None]] acceptors: Dict[int, Callable[[DLC], None]]
def __init__(self, device: Device) -> None: def __init__(
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.multiplexer = None self.multiplexer = None
@@ -963,7 +969,8 @@ class Server(EventEmitter):
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
self.l2cap_server = device.create_l2cap_server( self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu),
handler=self.on_connection,
) )
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int: def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:

View File

@@ -74,7 +74,7 @@ def codec_capabilities():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol): def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource( packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities() read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
) )
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump) protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -98,7 +98,7 @@ async def stream_packets(read_function, protocol):
# Stream the packets # Stream the packets
packet_source = SbcPacketSource( packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities() read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
) )
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump) source = protocol.add_source(packet_source.codec_capabilities, packet_pump)

View File

@@ -227,12 +227,34 @@ async def test_bidirectional_transfer():
assert server_received_bytes == message_bytes assert server_received_bytes == message_bytes
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu():
devices = TwoDevices()
await devices.setup_connection()
def on_channel_open(channel):
assert channel.peer_mtu == 456
def on_channel(channel):
channel.on('open', lambda: on_channel_open(channel))
server = devices.devices[1].create_l2cap_server(
spec=ClassicChannelSpec(mtu=345), handler=on_channel
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=ClassicChannelSpec(server.psm, mtu=456)
)
assert client_channel.peer_mtu == 345
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(): async def run():
test_helpers() test_helpers()
await test_basic_connection() await test_basic_connection()
await test_transfer() await test_transfer()
await test_bidirectional_transfer() await test_bidirectional_transfer()
await test_mtu()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------