Fix L2CAP errors

This commit is contained in:
Josh Wu
2025-07-29 22:34:43 +08:00
parent 0665e9ca5c
commit 822f97fa84
2 changed files with 83 additions and 54 deletions

View File

@@ -213,7 +213,7 @@ class L2CAP_Control_Frame:
fields: ClassVar[hci.Fields] = () fields: ClassVar[hci.Fields] = ()
code: int = dataclasses.field(default=0, init=False) code: int = dataclasses.field(default=0, init=False)
name: str = dataclasses.field(default='', init=False) name: str = dataclasses.field(default='', init=False)
_data: Optional[bytes] = dataclasses.field(default=None, init=False) _payload: Optional[bytes] = dataclasses.field(default=None, init=False)
identifier: int identifier: int
@@ -223,7 +223,8 @@ class L2CAP_Control_Frame:
subclass = L2CAP_Control_Frame.classes.get(code) subclass = L2CAP_Control_Frame.classes.get(code)
if subclass is None: if subclass is None:
instance = L2CAP_Control_Frame(pdu) instance = L2CAP_Control_Frame(identifier=identifier)
instance.payload = pdu[4:]
instance.code = CommandCode(code) instance.code = CommandCode(code)
instance.name = instance.code.name instance.name = instance.code.name
return instance return instance
@@ -232,11 +233,11 @@ class L2CAP_Control_Frame:
identifier=identifier, identifier=identifier,
) )
frame.identifier = identifier frame.identifier = identifier
frame.data = pdu[4:] frame.payload = pdu[4:]
if length != len(pdu): if length != len(frame.payload):
logger.warning( logger.warning(
color( color(
f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', f'!!! length mismatch: expected {length} but got {len(frame.payload)}',
'red', 'red',
) )
) )
@@ -273,34 +274,20 @@ class L2CAP_Control_Frame:
return subclass return subclass
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
self.identifier = kwargs.get('identifier', 0)
if self.fields:
if kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = hci.HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
self.data = pdu[4:] if pdu else b''
@property @property
def data(self) -> bytes: def payload(self) -> bytes:
if self._data is None: if self._payload is None:
self._data = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._data return self._payload
@data.setter @payload.setter
def data(self, parameters: bytes) -> None: def payload(self, payload: bytes) -> None:
self._data = parameters self._payload = payload
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return ( return (
struct.pack('<BBH', self.code, self.identifier, len(self.data) + 4) struct.pack('<BBH', self.code, self.identifier, len(self.payload))
+ self.data + self.payload
) )
def __str__(self) -> str: def __str__(self) -> str:
@@ -308,8 +295,8 @@ class L2CAP_Control_Frame:
if fields := getattr(self, 'fields', None): if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else: else:
if len(self.data) > 1: if len(self.payload) > 1:
result += f': {self.data.hex()}' result += f': {self.payload.hex()}'
return result return result

View File

@@ -22,12 +22,8 @@ import random
import pytest import pytest
from bumble.core import ProtocolError from bumble.core import ProtocolError
from bumble.l2cap import ( from bumble import l2cap
L2CAP_Connection_Request, from .test_utils import TwoDevices, async_barrier
ClassicChannelSpec,
LeCreditBasedChannelSpec,
)
from .test_utils import TwoDevices
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -41,42 +37,53 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_helpers(): def test_helpers():
psm = L2CAP_Connection_Request.serialize_psm(0x01) psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x01)
assert psm == bytes([0x01, 0x00]) assert psm == bytes([0x01, 0x00])
psm = L2CAP_Connection_Request.serialize_psm(0x1023) psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x1023)
assert psm == bytes([0x23, 0x10]) assert psm == bytes([0x23, 0x10])
psm = L2CAP_Connection_Request.serialize_psm(0x242311) psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x242311)
assert psm == bytes([0x11, 0x23, 0x24]) assert psm == bytes([0x11, 0x23, 0x24])
(offset, psm) = L2CAP_Connection_Request.parse_psm( (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x01, 0x00, 0x44]), 1 bytes([0x00, 0x01, 0x00, 0x44]), 1
) )
assert offset == 3 assert offset == 3
assert psm == 0x01 assert psm == 0x01
(offset, psm) = L2CAP_Connection_Request.parse_psm( (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x23, 0x10, 0x44]), 1 bytes([0x00, 0x23, 0x10, 0x44]), 1
) )
assert offset == 3 assert offset == 3
assert psm == 0x1023 assert psm == 0x1023
(offset, psm) = L2CAP_Connection_Request.parse_psm( (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1 bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
) )
assert offset == 4 assert offset == 4
assert psm == 0x242311 assert psm == 0x242311
rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88) rq = l2cap.L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88)
brq = bytes(rq) brq = bytes(rq)
srq = L2CAP_Connection_Request.from_bytes(brq) srq = l2cap.L2CAP_Connection_Request.from_bytes(brq)
assert isinstance(srq, L2CAP_Connection_Request) assert isinstance(srq, l2cap.L2CAP_Connection_Request)
assert srq.psm == rq.psm assert srq.psm == rq.psm
assert srq.source_cid == rq.source_cid assert srq.source_cid == rq.source_cid
assert srq.identifier == rq.identifier assert srq.identifier == rq.identifier
# -----------------------------------------------------------------------------
def test_unimplemented_control_frame():
frame = l2cap.L2CAP_Control_Frame(identifier=1)
frame.code = 0xFF
frame.payload = b'123456'
parsed = l2cap.L2CAP_Control_Frame.from_bytes(bytes(frame))
assert parsed.code == 0xFF
assert parsed.payload == b'123456'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_connection(): async def test_basic_connection():
@@ -87,7 +94,7 @@ async def test_basic_connection():
# Check that if there's no one listening, we can't connect # Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError): with pytest.raises(ProtocolError):
l2cap_channel = await devices.connections[0].create_l2cap_channel( l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(psm) spec=l2cap.LeCreditBasedChannelSpec(psm)
) )
# Now add a listener # Now add a listener
@@ -104,10 +111,10 @@ async def test_basic_connection():
channel.sink = on_data channel.sink = on_data
devices.devices[1].create_l2cap_server( devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(psm=1234), handler=on_coc spec=l2cap.LeCreditBasedChannelSpec(psm=1234), handler=on_coc
) )
l2cap_channel = await devices.connections[0].create_l2cap_channel( l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(psm) spec=l2cap.LeCreditBasedChannelSpec(psm)
) )
messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000)) messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
@@ -137,6 +144,41 @@ async def test_basic_connection():
assert sent_bytes == received_bytes assert sent_bytes == received_bytes
# -----------------------------------------------------------------------------
@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType))
async def test_l2cap_information_request(monkeypatch, info_type):
# TODO: Replace handlers with API when implemented
devices = await TwoDevices.create_with_connection()
# Register handlers
info_rsp = list[l2cap.L2CAP_Information_Response]()
def on_l2cap_information_response(connection, cid, frame):
info_rsp.append(frame)
assert (connection := devices.connections[0])
channel_manager = devices[0].l2cap_channel_manager
monkeypatch.setattr(
channel_manager,
'on_l2cap_information_response',
on_l2cap_information_response,
raising=False,
)
channel_manager.send_control_frame(
connection,
l2cap.L2CAP_LE_SIGNALING_CID,
l2cap.L2CAP_Information_Request(
identifier=channel_manager.next_identifier(connection),
info_type=info_type,
),
)
await async_barrier()
response = info_rsp[0]
assert response.result == l2cap.L2CAP_Information_Response.Result.SUCCESS
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps): async def transfer_payload(max_credits, mtu, mps):
devices = TwoDevices() devices = TwoDevices()
@@ -151,11 +193,11 @@ async def transfer_payload(max_credits, mtu, mps):
channel.sink = on_data channel.sink = on_data
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps), spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=on_coc, handler=on_coc,
) )
l2cap_channel = await devices.connections[0].create_l2cap_channel( l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(server.psm) spec=l2cap.LeCreditBasedChannelSpec(server.psm)
) )
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)] messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
@@ -205,10 +247,10 @@ async def test_bidirectional_transfer():
client_received.append(data) client_received.append(data)
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(), handler=on_server_coc spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc
) )
client_channel = await devices.connections[0].create_l2cap_channel( client_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(server.psm) spec=l2cap.LeCreditBasedChannelSpec(server.psm)
) )
client_channel.sink = on_client_data client_channel.sink = on_client_data
@@ -242,10 +284,10 @@ async def test_mtu():
channel.on('open', lambda: on_channel_open(channel)) channel.on('open', lambda: on_channel_open(channel))
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=ClassicChannelSpec(mtu=345), handler=on_channel spec=l2cap.ClassicChannelSpec(mtu=345), handler=on_channel
) )
client_channel = await devices.connections[0].create_l2cap_channel( client_channel = await devices.connections[0].create_l2cap_channel(
spec=ClassicChannelSpec(server.psm, mtu=456) spec=l2cap.ClassicChannelSpec(server.psm, mtu=456)
) )
assert client_channel.peer_mtu == 345 assert client_channel.peer_mtu == 345