From ee494a6543c8084cd2d64a934e8c162a634a14d0 Mon Sep 17 00:00:00 2001 From: uael Date: Tue, 7 Nov 2023 00:42:59 -0800 Subject: [PATCH] l2cap: refactor server side to allow deferred accept In order to avoid any breaking changes this re-impl current APIs with the exact same behavior. The previous impl was preventing one to defer the response to an l2cap channel connection request, both for BR/EDR basic channels and LE credit based ones. This commit change this to spawn a task on every channel incoming connection request, then all registered listeners are given a chance to accept it through a `asyncio.Future`. After a bit of delay, if none had accepted it, the connection is automatically rejected. --- bumble/l2cap.py | 598 +++++++++++++++++++++++++++++++----------------- 1 file changed, 389 insertions(+), 209 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 7a2f0ede..f245b3ea 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -35,8 +35,10 @@ from typing import ( Union, Deque, Iterable, + Set, SupportsBytes, TYPE_CHECKING, + overload, ) from .utils import deprecated @@ -237,6 +239,8 @@ class L2CAP_Control_Frame: classes: Dict[int, Type[L2CAP_Control_Frame]] = {} code = 0 name: str + identifier: int + pdu: bytes @staticmethod def from_bytes(pdu: bytes) -> L2CAP_Control_Frame: @@ -391,6 +395,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST ''' + psm: int + source_cid: int + @staticmethod def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: psm_length = 2 @@ -637,7 +644,11 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame): (CODE 0x14) ''' + le_psm: int source_cid: int + mtu: int + mps: int + initial_credits: int # ----------------------------------------------------------------------------- @@ -1375,19 +1386,14 @@ class LeCreditBasedChannel(EventEmitter): # ----------------------------------------------------------------------------- +@dataclasses.dataclass class ClassicChannelServer(EventEmitter): - def __init__( - self, - manager: ChannelManager, - psm: int, - handler: Optional[Callable[[ClassicChannel], Any]], - mtu: int, - ) -> None: + _close_closure: Callable[[], None] + psm: int + handler: Optional[Callable[[ClassicChannel], Any]] + + def __post_init__(self) -> None: super().__init__() - self.manager = manager - self.handler = handler - self.psm = psm - self.mtu = mtu def on_connection(self, channel: ClassicChannel) -> None: self.emit('connection', channel) @@ -1395,28 +1401,18 @@ class ClassicChannelServer(EventEmitter): self.handler(channel) def close(self) -> None: - if self.psm in self.manager.servers: - del self.manager.servers[self.psm] + self._close_closure() # ----------------------------------------------------------------------------- +@dataclasses.dataclass class LeCreditBasedChannelServer(EventEmitter): - def __init__( - self, - manager: ChannelManager, - psm: int, - handler: Optional[Callable[[LeCreditBasedChannel], Any]], - max_credits: int, - mtu: int, - mps: int, - ) -> None: + _close_closure: Callable[[], None] + psm: int + handler: Optional[Callable[[LeCreditBasedChannel], Any]] + + def __post_init__(self) -> None: super().__init__() - self.manager = manager - self.handler = handler - self.psm = psm - self.max_credits = max_credits - self.mtu = mtu - self.mps = mps def on_connection(self, channel: LeCreditBasedChannel) -> None: self.emit('connection', channel) @@ -1424,21 +1420,107 @@ class LeCreditBasedChannelServer(EventEmitter): self.handler(channel) def close(self) -> None: - if self.psm in self.manager.le_coc_servers: - del self.manager.le_coc_servers[self.psm] + self._close_closure() + + +# ----------------------------------------------------------------------------- +class PendingConnection: + """ + All pending connection types. + A `PendingConnection` is a temporary object used to accept an incoming connection + request, it contains the acceptor channel configuration preferences and transition + to the connected state through the `on_connection` callback. + This object is not supposed to live anymore once the channel is connected. + """ + + class Any: + """L2CAP any channel pending connection.""" + + on_connection: Callable[[Any], None] + mtu: int + + @dataclasses.dataclass + class Basic(Any): + """L2CAP basic channel pending connection.""" + + on_connection: Callable[[ClassicChannel], None] = lambda _: None + mtu: int = L2CAP_MIN_BR_EDR_MTU + + @dataclasses.dataclass + class LeCreditBased(Any): + """L2CAP LE credit based channel pending connection.""" + + on_connection: Callable[[LeCreditBasedChannel], None] = lambda _: None + mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU + mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS + + +# ----------------------------------------------------------------------------- +class IncomingConnection: + """ + All incoming connection types. + A `IncomingConnection` is a temporary object used to notify listeners of an + incoming channel connection request. It can accepted through the `future` field. + Multiple listeners can observe the same incoming connection request, but no more + than one can actually accept, first come first served. Thus it's recommended for + delayed accept to before check the state of the future field. + This object is not supposed to live anymore once accepted. + + Example: + ```python + fut = asyncio.Future() + + def listener(incoming: IncomingConnection.Any) -> None: + if isinstance(incoming, IncomingConnection.Basic) and incoming.psm == 0xcafe: + incoming.future.set_result(PendingConnection.Basic(fut.set_result, mtu=123)) + + device.l2cap_manager.listen(listener) + channel = await fut + ``` + """ + + @dataclasses.dataclass + class Any: + """L2CAP any incoming channel connection request.""" + + connection: Connection + psm: int + source_cid: int + + def __post_init__(self) -> None: + self.future: asyncio.Future[Any] = asyncio.Future() + + @dataclasses.dataclass + class Basic(Any): + """L2CAP incoming basic channel connection request.""" + + future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False) + + @dataclasses.dataclass + class LeCreditBased(Any): + """L2CAP incoming LE credit based channel connection request.""" + + mtu: int + mps: int + initial_credits: int + + future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field( + init=False + ) # ----------------------------------------------------------------------------- class ChannelManager: identifiers: Dict[int, int] channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]] - servers: Dict[int, ClassicChannelServer] le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]] - le_coc_servers: Dict[int, LeCreditBasedChannelServer] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] _host: Optional[Host] connection_parameters_update_response: Optional[asyncio.Future[int]] + listeners: List[Callable[[IncomingConnection.Any], None]] + used_psm: Set[int] def __init__( self, @@ -1452,15 +1534,15 @@ class ChannelManager: L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None, } - self.servers = {} # Servers accepting connections, by PSM self.le_coc_channels = ( {} ) # LE CoC channels, mapped by connection and destination cid - self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM self.le_coc_requests = {} # LE CoC connection requests, by identifier self.extended_features = extended_features self.connectionless_mtu = connectionless_mtu self.connection_parameters_update_response = None + self.listeners = [] + self.used_psm = set() @property def host(self) -> Host: @@ -1513,6 +1595,31 @@ class ChannelManager: raise RuntimeError('no free CID') + def allocate_psm(self) -> int: + # Find a free PSM + for candidate in range( + L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 + ): + if (candidate >> 8) % 2 == 1: + continue + if candidate in self.used_psm: + continue + return candidate + raise InvalidStateError('no free PSM') + + def allocate_spsm(self) -> int: + # Find a free sPSM + for candidate in range( + L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 + ): + if candidate in self.used_psm: + continue + return candidate + raise InvalidStateError('no free PSM') + + def free_psm(self, psm: int) -> None: + self.used_psm.remove(psm) + def next_identifier(self, connection: Connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier @@ -1527,6 +1634,35 @@ class ChannelManager: if cid in self.fixed_channels: del self.fixed_channels[cid] + @overload + def listen( + self, cb: Callable[[IncomingConnection.Basic], None] + ) -> Callable[[IncomingConnection.Basic], None]: + ... + + @overload + def listen( + self, cb: Callable[[IncomingConnection.LeCreditBased], None] + ) -> Callable[[IncomingConnection.LeCreditBased], None]: + ... + + def listen(self, cb: Any) -> Any: + if cb in self.listeners: + raise ValueError('listener already registered') + self.listeners.append(cb) + return cb + + @overload + def unlisten(self, cb: Callable[[IncomingConnection.Basic], None]) -> None: + ... + + @overload + def unlisten(self, cb: Callable[[IncomingConnection.LeCreditBased], None]) -> None: + ... + + def unlisten(self, cb: Any) -> None: + self.listeners.remove(cb) + @deprecated("Please use create_classic_server") def register_server( self, @@ -1534,7 +1670,7 @@ class ChannelManager: server: Callable[[ClassicChannel], Any], ) -> int: return self.create_classic_server( - handler=server, spec=ClassicChannelSpec(psm=psm) + handler=server, spec=ClassicChannelSpec(psm=None if psm == 0 else psm) ).psm def create_classic_server( @@ -1542,24 +1678,12 @@ class ChannelManager: spec: ClassicChannelSpec, handler: Optional[Callable[[ClassicChannel], Any]] = None, ) -> ClassicChannelServer: - if not spec.psm: - # Find a free PSM - for candidate in range( - L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 - ): - if (candidate >> 8) % 2 == 1: - continue - if candidate in self.servers: - continue - spec.psm = candidate - break - else: - raise InvalidStateError('no free PSM') + server: ClassicChannelServer + if spec.psm is None: + spec.psm = self.allocate_psm() else: - # Check that the PSM isn't already in use - if spec.psm in self.servers: - raise ValueError('PSM already in use') - + if spec.psm is self.used_psm: + raise ValueError(f'{spec.psm}: PSM already in use') # Check that the PSM is valid if spec.psm % 2 == 0: raise ValueError('invalid PSM (not odd)') @@ -1568,10 +1692,22 @@ class ChannelManager: if check % 2 != 0: raise ValueError('invalid PSM') check >>= 8 + self.used_psm.add(spec.psm) - self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) + def listener(incoming: IncomingConnection.Basic) -> None: + if incoming.psm == spec.psm: + incoming.future.set_result( + PendingConnection.Basic(server.on_connection, spec.mtu) + ) - return self.servers[spec.psm] + def close() -> None: + self.unlisten(listener) + assert spec.psm is not None + self.free_psm(spec.psm) + + self.listen(listener) + server = ClassicChannelServer(close, spec.psm, handler) + return server @deprecated("Please use create_le_credit_based_server()") def register_le_coc_server( @@ -1594,32 +1730,30 @@ class ChannelManager: spec: LeCreditBasedChannelSpec, handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None, ) -> LeCreditBasedChannelServer: - if not spec.psm: - # Find a free PSM - for candidate in range( - L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 - ): - if candidate in self.le_coc_servers: - continue - spec.psm = candidate - break - else: - raise InvalidStateError('no free PSM') + server: LeCreditBasedChannelServer + if spec.psm is None: + spec.psm = self.allocate_psm() else: - # Check that the PSM isn't already in use - if spec.psm in self.le_coc_servers: - raise ValueError('PSM already in use') + if spec.psm is self.used_psm: + raise ValueError(f'{spec.psm}: SPSM already in use') + self.used_psm.add(spec.psm) - self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( - self, - spec.psm, - handler, - max_credits=spec.max_credits, - mtu=spec.mtu, - mps=spec.mps, - ) + def listener(incoming: IncomingConnection.LeCreditBased) -> None: + if incoming.psm == spec.psm: + incoming.future.set_result( + PendingConnection.LeCreditBased( + server.on_connection, spec.mtu, spec.mps, spec.max_credits + ) + ) - return self.le_coc_servers[spec.psm] + def close() -> None: + self.unlisten(listener) + assert spec.psm is not None + self.free_psm(spec.psm) + + self.listen(listener) + server = LeCreditBasedChannelServer(close, spec.psm, handler) + return server def on_disconnection(self, connection_handle: int, _reason: int) -> None: logger.debug(f'disconnection from {connection_handle}, cleaning up channels') @@ -1719,15 +1853,62 @@ class ChannelManager: logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') def on_l2cap_connection_request( - self, connection: Connection, cid: int, request + self, connection: Connection, cid: int, request: L2CAP_Connection_Request ) -> None: - # Check if there's a server for this PSM - server = self.servers.get(request.psm) - if server: - # Find a free CID for this new channel - connection_channels = self.channels.setdefault(connection.handle, {}) - source_cid = self.find_free_br_edr_cid(connection_channels) - if source_cid is None: # Should never happen! + + # Asynchronous connection request handling. + async def handle_connection_request() -> None: + incoming = IncomingConnection.Basic( + connection, request.psm, request.source_cid + ) + + # Dispatch incoming connection. + for listener in self.listeners: + if not incoming.future.done(): + listener(incoming) + + try: + pending = await asyncio.wait_for(incoming.future, timeout=3.0) + except asyncio.TimeoutError as e: + incoming.future.cancel(e) + pending = None + + if pending: + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_br_edr_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_Connection_Response( + identifier=request.identifier, + destination_cid=request.source_cid, + source_cid=0, + # pylint: disable=line-too-long + result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + status=0x0000, + ), + ) + return + + # Create a new channel + logger.debug( + f'creating server channel with cid={source_cid} for psm {request.psm}' + ) + channel = ClassicChannel( + self, connection, cid, request.psm, source_cid, pending.mtu + ) + connection_channels[source_cid] = channel + + # Notify + pending.on_connection(channel) + channel.on_connection_request(request) + else: + logger.warning( + f'No server for connection 0x{connection.handle:04X} ' + f'on PSM {request.psm}' + ) self.send_control_frame( connection, cid, @@ -1736,41 +1917,13 @@ class ChannelManager: destination_cid=request.source_cid, source_cid=0, # pylint: disable=line-too-long - result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, status=0x0000, ), ) - return - # Create a new channel - logger.debug( - f'creating server channel with cid={source_cid} for psm {request.psm}' - ) - channel = ClassicChannel( - self, connection, cid, request.psm, source_cid, server.mtu - ) - connection_channels[source_cid] = channel - - # Notify - server.on_connection(channel) - channel.on_connection_request(request) - else: - logger.warning( - f'No server for connection 0x{connection.handle:04X} ' - f'on PSM {request.psm}' - ) - self.send_control_frame( - connection, - cid, - L2CAP_Connection_Response( - identifier=request.identifier, - destination_cid=request.source_cid, - source_cid=0, - # pylint: disable=line-too-long - result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, - status=0x0000, - ), - ) + # Spawn connection request handling. + connection.abort_on('disconnection', handle_connection_request()) def on_l2cap_connection_response( self, connection: Connection, cid: int, response @@ -1971,108 +2124,135 @@ class ChannelManager: ) def on_l2cap_le_credit_based_connection_request( - self, connection: Connection, cid: int, request + self, + connection: Connection, + cid: int, + request: L2CAP_LE_Credit_Based_Connection_Request, ) -> None: - if request.le_psm in self.le_coc_servers: - server = self.le_coc_servers[request.le_psm] - # Check that the CID isn't already used - le_connection_channels = self.le_coc_channels.setdefault( - connection.handle, {} - ) - if request.source_cid in le_connection_channels: - logger.warning(f'source CID {request.source_cid} already in use') - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=server.mtu, - mps=server.mps, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, - ), - ) - return - - # Find a free CID for this new channel - connection_channels = self.channels.setdefault(connection.handle, {}) - source_cid = self.find_free_le_cid(connection_channels) - if source_cid is None: # Should never happen! - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=server.mtu, - mps=server.mps, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, - ), - ) - return - - # Create a new channel - logger.debug( - f'creating LE CoC server channel with cid={source_cid} for psm ' - f'{request.le_psm}' - ) - channel = LeCreditBasedChannel( - self, + # Asynchronous connection request handling. + async def handle_connection_request() -> None: + incoming = IncomingConnection.LeCreditBased( connection, request.le_psm, - source_cid, request.source_cid, - server.mtu, - server.mps, - request.initial_credits, request.mtu, request.mps, - server.max_credits, - True, - ) - connection_channels[source_cid] = channel - le_connection_channels[request.source_cid] = channel - - # Respond - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=source_cid, - mtu=server.mtu, - mps=server.mps, - initial_credits=server.max_credits, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, - ), + request.initial_credits, ) - # Notify - server.on_connection(channel) - else: - logger.info( - f'No LE server for connection 0x{connection.handle:04X} ' - f'on PSM {request.le_psm}' - ) - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, - ), - ) + # Dispatch incoming connection. + for listener in self.listeners: + if not incoming.future.done(): + listener(incoming) + + try: + pending = await asyncio.wait_for(incoming.future, timeout=3.0) + except asyncio.TimeoutError as e: + incoming.future.cancel(e) + pending = None + + if pending: + # Check that the CID isn't already used + le_connection_channels = self.le_coc_channels.setdefault( + connection.handle, {} + ) + if request.source_cid in le_connection_channels: + logger.warning(f'source CID {request.source_cid} already in use') + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=0, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + ), + ) + return + + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_le_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=0, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + ), + ) + return + + # Create a new channel + logger.debug( + f'creating LE CoC server channel with cid={source_cid} for psm ' + f'{request.le_psm}' + ) + channel = LeCreditBasedChannel( + self, + connection, + request.le_psm, + source_cid, + request.source_cid, + pending.mtu, + pending.mps, + request.initial_credits, + request.mtu, + request.mps, + pending.max_credits, + True, + ) + connection_channels[source_cid] = channel + le_connection_channels[request.source_cid] = channel + + # Respond + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=source_cid, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=pending.max_credits, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, + ), + ) + + # Notify + pending.on_connection(channel) + else: + logger.info( + f'No LE server for connection 0x{connection.handle:04X} ' + f'on PSM {request.le_psm}' + ) + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + initial_credits=0, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, + ), + ) + + # Spawn connection request handling. + connection.abort_on('disconnection', handle_connection_request()) def on_l2cap_le_credit_based_connection_response( self, connection: Connection, _cid: int, response