l2cap: refactor server side to allow deferred accept

In order to avoid any breaking changes this re-impl current APIs with
the exact same behavior.

The previous impl was preventing one to defer the response to an l2cap
channel connection request, both for BR/EDR basic channels and LE credit
based ones. This commit change this to spawn a task on every channel
incoming connection request, then all registered listeners are given a
chance to accept it through a `asyncio.Future`. After a bit of delay, if
none had accepted it, the connection is automatically rejected.
This commit is contained in:
uael
2023-11-07 00:42:59 -08:00
parent 2cd4f84800
commit ee494a6543
+389 -209
View File
@@ -35,8 +35,10 @@ from typing import (
Union,
Deque,
Iterable,
Set,
SupportsBytes,
TYPE_CHECKING,
overload,
)
from .utils import deprecated
@@ -237,6 +239,8 @@ class L2CAP_Control_Frame:
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name: str
identifier: int
pdu: bytes
@staticmethod
def from_bytes(pdu: bytes) -> L2CAP_Control_Frame:
@@ -391,6 +395,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
psm: int
source_cid: int
@staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
@@ -637,7 +644,11 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
(CODE 0x14)
'''
le_psm: int
source_cid: int
mtu: int
mps: int
initial_credits: int
# -----------------------------------------------------------------------------
@@ -1375,19 +1386,14 @@ class LeCreditBasedChannel(EventEmitter):
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[ClassicChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
@@ -1395,28 +1401,18 @@ class ClassicChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[LeCreditBasedChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
@@ -1424,21 +1420,107 @@ class LeCreditBasedChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
class PendingConnection:
"""
All pending connection types.
A `PendingConnection` is a temporary object used to accept an incoming connection
request, it contains the acceptor channel configuration preferences and transition
to the connected state through the `on_connection` callback.
This object is not supposed to live anymore once the channel is connected.
"""
class Any:
"""L2CAP any channel pending connection."""
on_connection: Callable[[Any], None]
mtu: int
@dataclasses.dataclass
class Basic(Any):
"""L2CAP basic channel pending connection."""
on_connection: Callable[[ClassicChannel], None] = lambda _: None
mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP LE credit based channel pending connection."""
on_connection: Callable[[LeCreditBasedChannel], None] = lambda _: None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
class IncomingConnection:
"""
All incoming connection types.
A `IncomingConnection` is a temporary object used to notify listeners of an
incoming channel connection request. It can accepted through the `future` field.
Multiple listeners can observe the same incoming connection request, but no more
than one can actually accept, first come first served. Thus it's recommended for
delayed accept to before check the state of the future field.
This object is not supposed to live anymore once accepted.
Example:
```python
fut = asyncio.Future()
def listener(incoming: IncomingConnection.Any) -> None:
if isinstance(incoming, IncomingConnection.Basic) and incoming.psm == 0xcafe:
incoming.future.set_result(PendingConnection.Basic(fut.set_result, mtu=123))
device.l2cap_manager.listen(listener)
channel = await fut
```
"""
@dataclasses.dataclass
class Any:
"""L2CAP any incoming channel connection request."""
connection: Connection
psm: int
source_cid: int
def __post_init__(self) -> None:
self.future: asyncio.Future[Any] = asyncio.Future()
@dataclasses.dataclass
class Basic(Any):
"""L2CAP incoming basic channel connection request."""
future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False)
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP incoming LE credit based channel connection request."""
mtu: int
mps: int
initial_credits: int
future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field(
init=False
)
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
connection_parameters_update_response: Optional[asyncio.Future[int]]
listeners: List[Callable[[IncomingConnection.Any], None]]
used_psm: Set[int]
def __init__(
self,
@@ -1452,15 +1534,15 @@ class ChannelManager:
L2CAP_SIGNALING_CID: None,
L2CAP_LE_SIGNALING_CID: None,
}
self.servers = {} # Servers accepting connections, by PSM
self.le_coc_channels = (
{}
) # LE CoC channels, mapped by connection and destination cid
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.extended_features = extended_features
self.connectionless_mtu = connectionless_mtu
self.connection_parameters_update_response = None
self.listeners = []
self.used_psm = set()
@property
def host(self) -> Host:
@@ -1513,6 +1595,31 @@ class ChannelManager:
raise RuntimeError('no free CID')
def allocate_psm(self) -> int:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def allocate_spsm(self) -> int:
# Find a free sPSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def free_psm(self, psm: int) -> None:
self.used_psm.remove(psm)
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -1527,6 +1634,35 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@overload
def listen(
self, cb: Callable[[IncomingConnection.Basic], None]
) -> Callable[[IncomingConnection.Basic], None]:
...
@overload
def listen(
self, cb: Callable[[IncomingConnection.LeCreditBased], None]
) -> Callable[[IncomingConnection.LeCreditBased], None]:
...
def listen(self, cb: Any) -> Any:
if cb in self.listeners:
raise ValueError('listener already registered')
self.listeners.append(cb)
return cb
@overload
def unlisten(self, cb: Callable[[IncomingConnection.Basic], None]) -> None:
...
@overload
def unlisten(self, cb: Callable[[IncomingConnection.LeCreditBased], None]) -> None:
...
def unlisten(self, cb: Any) -> None:
self.listeners.remove(cb)
@deprecated("Please use create_classic_server")
def register_server(
self,
@@ -1534,7 +1670,7 @@ class ChannelManager:
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
handler=server, spec=ClassicChannelSpec(psm=None if psm == 0 else psm)
).psm
def create_classic_server(
@@ -1542,24 +1678,12 @@ class ChannelManager:
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: ClassicChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
@@ -1568,10 +1692,22 @@ class ChannelManager:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.used_psm.add(spec.psm)
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
def listener(incoming: IncomingConnection.Basic) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.Basic(server.on_connection, spec.mtu)
)
return self.servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = ClassicChannelServer(close, spec.psm, handler)
return server
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
@@ -1594,32 +1730,30 @@ class ChannelManager:
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.le_coc_servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: LeCreditBasedChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: SPSM already in use')
self.used_psm.add(spec.psm)
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
spec.psm,
handler,
max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
)
def listener(incoming: IncomingConnection.LeCreditBased) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.LeCreditBased(
server.on_connection, spec.mtu, spec.mps, spec.max_credits
)
)
return self.le_coc_servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = LeCreditBasedChannelServer(close, spec.psm, handler)
return server
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1719,15 +1853,62 @@ class ChannelManager:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(
self, connection: Connection, cid: int, request
self, connection: Connection, cid: int, request: L2CAP_Connection_Request
) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.Basic(
connection, request.psm, request.source_cid
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, pending.mtu
)
connection_channels[source_cid] = channel
# Notify
pending.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
@@ -1736,41 +1917,13 @@ class ChannelManager:
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.mtu
)
connection_channels[source_cid] = channel
# Notify
server.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_connection_response(
self, connection: Connection, cid: int, response
@@ -1971,108 +2124,135 @@ class ChannelManager:
)
def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request
self,
connection: Connection,
cid: int,
request: L2CAP_LE_Credit_Based_Connection_Request,
) -> None:
if request.le_psm in self.le_coc_servers:
server = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.LeCreditBased(
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
request.initial_credits,
)
# Notify
server.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
pending.mtu,
pending.mps,
request.initial_credits,
request.mtu,
request.mps,
pending.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=pending.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
# Notify
pending.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_le_credit_based_connection_response(
self, connection: Connection, _cid: int, response