diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 856b7f39..085e0360 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -213,7 +213,7 @@ class L2CAP_Control_Frame: fields: ClassVar[hci.Fields] = () code: int = dataclasses.field(default=0, init=False) name: str = dataclasses.field(default='', init=False) - _data: Optional[bytes] = dataclasses.field(default=None, init=False) + _payload: Optional[bytes] = dataclasses.field(default=None, init=False) identifier: int @@ -223,7 +223,8 @@ class L2CAP_Control_Frame: subclass = L2CAP_Control_Frame.classes.get(code) if subclass is None: - instance = L2CAP_Control_Frame(pdu) + instance = L2CAP_Control_Frame(identifier=identifier) + instance.payload = pdu[4:] instance.code = CommandCode(code) instance.name = instance.code.name return instance @@ -232,11 +233,11 @@ class L2CAP_Control_Frame: identifier=identifier, ) frame.identifier = identifier - frame.data = pdu[4:] - if length != len(pdu): + frame.payload = pdu[4:] + if length != len(frame.payload): logger.warning( color( - f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', + f'!!! length mismatch: expected {length} but got {len(frame.payload)}', 'red', ) ) @@ -273,34 +274,20 @@ class L2CAP_Control_Frame: return subclass - def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: - self.identifier = kwargs.get('identifier', 0) - if self.fields: - if kwargs: - hci.HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - data = hci.HCI_Object.dict_to_bytes(kwargs, self.fields) - pdu = ( - bytes([self.code, self.identifier]) - + struct.pack(' bytes: - if self._data is None: - self._data = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) - return self._data + def payload(self) -> bytes: + if self._payload is None: + self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) + return self._payload - @data.setter - def data(self, parameters: bytes) -> None: - self._data = parameters + @payload.setter + def payload(self, payload: bytes) -> None: + self._payload = payload def __bytes__(self) -> bytes: return ( - struct.pack(' str: @@ -308,8 +295,8 @@ class L2CAP_Control_Frame: if fields := getattr(self, 'fields', None): result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') else: - if len(self.data) > 1: - result += f': {self.data.hex()}' + if len(self.payload) > 1: + result += f': {self.payload.hex()}' return result diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 3ece29ab..e2c7b29f 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -22,12 +22,8 @@ import random import pytest from bumble.core import ProtocolError -from bumble.l2cap import ( - L2CAP_Connection_Request, - ClassicChannelSpec, - LeCreditBasedChannelSpec, -) -from .test_utils import TwoDevices +from bumble import l2cap +from .test_utils import TwoDevices, async_barrier # ----------------------------------------------------------------------------- @@ -41,42 +37,53 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- def test_helpers(): - psm = L2CAP_Connection_Request.serialize_psm(0x01) + psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x01) assert psm == bytes([0x01, 0x00]) - psm = L2CAP_Connection_Request.serialize_psm(0x1023) + psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x1023) assert psm == bytes([0x23, 0x10]) - psm = L2CAP_Connection_Request.serialize_psm(0x242311) + psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x242311) assert psm == bytes([0x11, 0x23, 0x24]) - (offset, psm) = L2CAP_Connection_Request.parse_psm( + (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm( bytes([0x00, 0x01, 0x00, 0x44]), 1 ) assert offset == 3 assert psm == 0x01 - (offset, psm) = L2CAP_Connection_Request.parse_psm( + (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm( bytes([0x00, 0x23, 0x10, 0x44]), 1 ) assert offset == 3 assert psm == 0x1023 - (offset, psm) = L2CAP_Connection_Request.parse_psm( + (offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm( bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1 ) assert offset == 4 assert psm == 0x242311 - rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88) + rq = l2cap.L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88) brq = bytes(rq) - srq = L2CAP_Connection_Request.from_bytes(brq) - assert isinstance(srq, L2CAP_Connection_Request) + srq = l2cap.L2CAP_Connection_Request.from_bytes(brq) + assert isinstance(srq, l2cap.L2CAP_Connection_Request) assert srq.psm == rq.psm assert srq.source_cid == rq.source_cid assert srq.identifier == rq.identifier +# ----------------------------------------------------------------------------- +def test_unimplemented_control_frame(): + frame = l2cap.L2CAP_Control_Frame(identifier=1) + frame.code = 0xFF + frame.payload = b'123456' + + parsed = l2cap.L2CAP_Control_Frame.from_bytes(bytes(frame)) + assert parsed.code == 0xFF + assert parsed.payload == b'123456' + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_basic_connection(): @@ -87,7 +94,7 @@ async def test_basic_connection(): # Check that if there's no one listening, we can't connect with pytest.raises(ProtocolError): l2cap_channel = await devices.connections[0].create_l2cap_channel( - spec=LeCreditBasedChannelSpec(psm) + spec=l2cap.LeCreditBasedChannelSpec(psm) ) # Now add a listener @@ -104,10 +111,10 @@ async def test_basic_connection(): channel.sink = on_data devices.devices[1].create_l2cap_server( - spec=LeCreditBasedChannelSpec(psm=1234), handler=on_coc + spec=l2cap.LeCreditBasedChannelSpec(psm=1234), handler=on_coc ) l2cap_channel = await devices.connections[0].create_l2cap_channel( - spec=LeCreditBasedChannelSpec(psm) + spec=l2cap.LeCreditBasedChannelSpec(psm) ) messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000)) @@ -137,6 +144,41 @@ async def test_basic_connection(): assert sent_bytes == received_bytes +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType)) +async def test_l2cap_information_request(monkeypatch, info_type): + # TODO: Replace handlers with API when implemented + devices = await TwoDevices.create_with_connection() + + # Register handlers + info_rsp = list[l2cap.L2CAP_Information_Response]() + + def on_l2cap_information_response(connection, cid, frame): + info_rsp.append(frame) + + assert (connection := devices.connections[0]) + channel_manager = devices[0].l2cap_channel_manager + monkeypatch.setattr( + channel_manager, + 'on_l2cap_information_response', + on_l2cap_information_response, + raising=False, + ) + + channel_manager.send_control_frame( + connection, + l2cap.L2CAP_LE_SIGNALING_CID, + l2cap.L2CAP_Information_Request( + identifier=channel_manager.next_identifier(connection), + info_type=info_type, + ), + ) + + await async_barrier() + response = info_rsp[0] + assert response.result == l2cap.L2CAP_Information_Response.Result.SUCCESS + + # ----------------------------------------------------------------------------- async def transfer_payload(max_credits, mtu, mps): devices = TwoDevices() @@ -151,11 +193,11 @@ async def transfer_payload(max_credits, mtu, mps): channel.sink = on_data server = devices.devices[1].create_l2cap_server( - spec=LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps), + spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps), handler=on_coc, ) l2cap_channel = await devices.connections[0].create_l2cap_channel( - spec=LeCreditBasedChannelSpec(server.psm) + spec=l2cap.LeCreditBasedChannelSpec(server.psm) ) messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)] @@ -205,10 +247,10 @@ async def test_bidirectional_transfer(): client_received.append(data) server = devices.devices[1].create_l2cap_server( - spec=LeCreditBasedChannelSpec(), handler=on_server_coc + spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc ) client_channel = await devices.connections[0].create_l2cap_channel( - spec=LeCreditBasedChannelSpec(server.psm) + spec=l2cap.LeCreditBasedChannelSpec(server.psm) ) client_channel.sink = on_client_data @@ -242,10 +284,10 @@ async def test_mtu(): channel.on('open', lambda: on_channel_open(channel)) server = devices.devices[1].create_l2cap_server( - spec=ClassicChannelSpec(mtu=345), handler=on_channel + spec=l2cap.ClassicChannelSpec(mtu=345), handler=on_channel ) client_channel = await devices.connections[0].create_l2cap_channel( - spec=ClassicChannelSpec(server.psm, mtu=456) + spec=l2cap.ClassicChannelSpec(server.psm, mtu=456) ) assert client_channel.peer_mtu == 345