From 190529184edd3c404c022bca7d0afe8b5a5381bc Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 27 Jul 2023 15:31:08 +0800 Subject: [PATCH] L2CAP: Import device.Connection for typing --- bumble/l2cap.py | 78 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 119cb13..7abde0d 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -33,6 +33,7 @@ from typing import ( Union, Deque, Iterable, + TYPE_CHECKING, ) from .colors import color @@ -44,6 +45,9 @@ from .hci import ( name_or_number, ) +if TYPE_CHECKING: + from bumble.device import Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -720,11 +724,12 @@ class Channel(EventEmitter): response: Optional[asyncio.Future[bytes]] sink: Optional[Callable[[bytes], Any]] state: int + connection: 'Connection' def __init__( self, manager: 'ChannelManager', - connection, + connection: 'Connection', signaling_cid: int, psm: int, source_cid: int, @@ -1028,6 +1033,7 @@ class LeConnectionOrientedChannel(EventEmitter): disconnection_result: Optional[asyncio.Future[None]] out_sdu: Optional[bytes] state: int + connection: 'Connection' @staticmethod def state_name(state: int) -> str: @@ -1036,7 +1042,7 @@ class LeConnectionOrientedChannel(EventEmitter): def __init__( self, manager: 'ChannelManager', - connection, + connection: 'Connection', le_psm: int, source_cid: int, destination_cid: int, @@ -1465,7 +1471,7 @@ class ChannelManager: ): raise ValueError('MPS out of range') - def next_identifier(self, connection) -> int: + def next_identifier(self, connection: 'Connection') -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier return identifier @@ -1568,7 +1574,7 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) - def on_pdu(self, connection, cid: int, pdu) -> None: + def on_pdu(self, connection: 'Connection', cid: int, pdu) -> None: if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): # Parse the L2CAP payload into a Control Frame object control_frame = L2CAP_Control_Frame.from_bytes(pdu) @@ -1589,7 +1595,9 @@ class ChannelManager: channel.on_pdu(pdu) - def send_control_frame(self, connection, cid: int, control_frame) -> None: + def send_control_frame( + self, connection: 'Connection', cid: int, control_frame + ) -> None: logger.debug( f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1597,7 +1605,9 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) - def on_control_frame(self, connection, cid: int, control_frame) -> None: + def on_control_frame( + self, connection: 'Connection', cid: int, control_frame + ) -> None: logger.debug( f'{color("<<< Received L2CAP Signaling Control Frame", "green")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1634,10 +1644,14 @@ class ChannelManager: ), ) - def on_l2cap_command_reject(self, _connection, _cid: int, packet) -> None: + def on_l2cap_command_reject( + self, _connection: 'Connection', _cid: int, packet + ) -> None: logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') - def on_l2cap_connection_request(self, connection, cid: int, request) -> None: + def on_l2cap_connection_request( + self, connection: 'Connection', cid: int, request + ) -> None: # Check if there's a server for this PSM server = self.servers.get(request.psm) if server: @@ -1689,7 +1703,9 @@ class ChannelManager: ), ) - def on_l2cap_connection_response(self, connection, cid: int, response) -> None: + def on_l2cap_connection_response( + self, connection: 'Connection', cid: int, response + ) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1704,7 +1720,9 @@ class ChannelManager: channel.on_connection_response(response) - def on_l2cap_configure_request(self, connection, cid: int, request) -> None: + def on_l2cap_configure_request( + self, connection: 'Connection', cid: int, request + ) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1719,7 +1737,9 @@ class ChannelManager: channel.on_configure_request(request) - def on_l2cap_configure_response(self, connection, cid: int, response) -> None: + def on_l2cap_configure_response( + self, connection: 'Connection', cid: int, response + ) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1734,7 +1754,9 @@ class ChannelManager: channel.on_configure_response(response) - def on_l2cap_disconnection_request(self, connection, cid: int, request) -> None: + def on_l2cap_disconnection_request( + self, connection: 'Connection', cid: int, request + ) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1749,7 +1771,9 @@ class ChannelManager: channel.on_disconnection_request(request) - def on_l2cap_disconnection_response(self, connection, cid: int, response) -> None: + def on_l2cap_disconnection_response( + self, connection: 'Connection', cid: int, response + ) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1764,7 +1788,9 @@ class ChannelManager: channel.on_disconnection_response(response) - def on_l2cap_echo_request(self, connection, cid: int, request) -> None: + def on_l2cap_echo_request( + self, connection: 'Connection', cid: int, request + ) -> None: logger.debug(f'<<< Echo request: data={request.data.hex()}') self.send_control_frame( connection, @@ -1772,11 +1798,15 @@ class ChannelManager: L2CAP_Echo_Response(identifier=request.identifier, data=request.data), ) - def on_l2cap_echo_response(self, _connection, _cid: int, response) -> None: + def on_l2cap_echo_response( + self, _connection: 'Connection', _cid: int, response + ) -> None: logger.debug(f'<<< Echo response: data={response.data.hex()}') # TODO notify listeners - def on_l2cap_information_request(self, connection, cid: int, request) -> None: + def on_l2cap_information_request( + self, connection: 'Connection', cid: int, request + ) -> None: if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU: result = L2CAP_Information_Response.SUCCESS data = self.connectionless_mtu.to_bytes(2, 'little') @@ -1801,7 +1831,7 @@ class ChannelManager: ) def on_l2cap_connection_parameter_update_request( - self, connection, cid: int, request + self, connection: 'Connection', cid: int, request ): if connection.role == BT_CENTRAL_ROLE: self.send_control_frame( @@ -1834,13 +1864,13 @@ class ChannelManager: ) def on_l2cap_connection_parameter_update_response( - self, connection, cid: int, response + self, connection: 'Connection', cid: int, response ) -> None: # TODO: check response pass def on_l2cap_le_credit_based_connection_request( - self, connection, cid: int, request + self, connection: 'Connection', cid: int, request ) -> None: if request.le_psm in self.le_coc_servers: (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] @@ -1944,7 +1974,7 @@ class ChannelManager: ) def on_l2cap_le_credit_based_connection_response( - self, connection, _cid: int, response + self, connection: 'Connection', _cid: int, response ) -> None: # Find the pending request by identifier request = self.le_coc_requests.get(response.identifier) @@ -1968,7 +1998,9 @@ class ChannelManager: # Process the response channel.on_connection_response(response) - def on_l2cap_le_flow_control_credit(self, connection, _cid: int, credit) -> None: + def on_l2cap_le_flow_control_credit( + self, connection: 'Connection', _cid: int, credit + ) -> None: channel = self.find_le_coc_channel(connection.handle, credit.cid) if channel is None: logger.warning(f'received credits for an unknown channel (cid={credit.cid}') @@ -1983,7 +2015,7 @@ class ChannelManager: del connection_channels[channel.source_cid] async def open_le_coc( - self, connection, psm: int, max_credits: int, mtu: int, mps: int + self, connection: 'Connection', psm: int, max_credits: int, mtu: int, mps: int ) -> LeConnectionOrientedChannel: self.check_le_coc_parameters(max_credits, mtu, mps) @@ -2025,7 +2057,7 @@ class ChannelManager: return channel - async def connect(self, connection, psm: int) -> Channel: + async def connect(self, connection: 'Connection', psm: int) -> Channel: # NOTE: this implementation hard-codes BR/EDR # Find a free CID for a new channel