Compare commits

...

2 Commits

Author SHA1 Message Date
uael
412fd0f78a pandora: implement L2CAP pandora service
Co-authored-by: Josh Wu <joshwu@google.com>
2023-11-07 00:58:33 -08:00
uael
ee494a6543 l2cap: refactor server side to allow deferred accept
In order to avoid any breaking changes this re-impl current APIs with
the exact same behavior.

The previous impl was preventing one to defer the response to an l2cap
channel connection request, both for BR/EDR basic channels and LE credit
based ones. This commit change this to spawn a task on every channel
incoming connection request, then all registered listeners are given a
chance to accept it through a `asyncio.Future`. After a bit of delay, if
none had accepted it, the connection is automatically rejected.
2023-11-07 00:43:02 -08:00
4 changed files with 682 additions and 210 deletions

View File

@@ -35,8 +35,10 @@ from typing import (
Union,
Deque,
Iterable,
Set,
SupportsBytes,
TYPE_CHECKING,
overload,
)
from .utils import deprecated
@@ -237,6 +239,8 @@ class L2CAP_Control_Frame:
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name: str
identifier: int
pdu: bytes
@staticmethod
def from_bytes(pdu: bytes) -> L2CAP_Control_Frame:
@@ -391,6 +395,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
psm: int
source_cid: int
@staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
@@ -637,7 +644,11 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
(CODE 0x14)
'''
le_psm: int
source_cid: int
mtu: int
mps: int
initial_credits: int
# -----------------------------------------------------------------------------
@@ -1375,19 +1386,14 @@ class LeCreditBasedChannel(EventEmitter):
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[ClassicChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
@@ -1395,28 +1401,18 @@ class ClassicChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[LeCreditBasedChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
@@ -1424,21 +1420,107 @@ class LeCreditBasedChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
class PendingConnection:
"""
All pending connection types.
A `PendingConnection` is a temporary object used to accept an incoming connection
request, it contains the acceptor channel configuration preferences and transition
to the connected state through the `on_connection` callback.
This object is not supposed to live anymore once the channel is connected.
"""
class Any:
"""L2CAP any channel pending connection."""
on_connection: Callable[[Any], None]
mtu: int
@dataclasses.dataclass
class Basic(Any):
"""L2CAP basic channel pending connection."""
on_connection: Callable[[ClassicChannel], None] = lambda _: None
mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP LE credit based channel pending connection."""
on_connection: Callable[[LeCreditBasedChannel], None] = lambda _: None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
class IncomingConnection:
"""
All incoming connection types.
A `IncomingConnection` is a temporary object used to notify listeners of an
incoming channel connection request. It can accepted through the `future` field.
Multiple listeners can observe the same incoming connection request, but no more
than one can actually accept, first come first served. Thus it's recommended for
delayed accept to before check the state of the future field.
This object is not supposed to live anymore once accepted.
Example:
```python
fut = asyncio.Future()
def listener(incoming: IncomingConnection.Any) -> None:
if isinstance(incoming, IncomingConnection.Basic) and incoming.psm == 0xcafe:
incoming.future.set_result(PendingConnection.Basic(fut.set_result, mtu=123))
device.l2cap_manager.listen(listener)
channel = await fut
```
"""
@dataclasses.dataclass
class Any:
"""L2CAP any incoming channel connection request."""
connection: Connection
psm: int
source_cid: int
def __post_init__(self) -> None:
self.future: asyncio.Future[Any] = asyncio.Future()
@dataclasses.dataclass
class Basic(Any):
"""L2CAP incoming basic channel connection request."""
future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False)
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP incoming LE credit based channel connection request."""
mtu: int
mps: int
initial_credits: int
future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field(
init=False
)
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
connection_parameters_update_response: Optional[asyncio.Future[int]]
listeners: List[Callable[[IncomingConnection.Any], None]]
used_psm: Set[int]
def __init__(
self,
@@ -1452,15 +1534,15 @@ class ChannelManager:
L2CAP_SIGNALING_CID: None,
L2CAP_LE_SIGNALING_CID: None,
}
self.servers = {} # Servers accepting connections, by PSM
self.le_coc_channels = (
{}
) # LE CoC channels, mapped by connection and destination cid
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.extended_features = extended_features
self.connectionless_mtu = connectionless_mtu
self.connection_parameters_update_response = None
self.listeners = []
self.used_psm = set()
@property
def host(self) -> Host:
@@ -1513,6 +1595,31 @@ class ChannelManager:
raise RuntimeError('no free CID')
def allocate_psm(self) -> int:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def allocate_spsm(self) -> int:
# Find a free sPSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def free_psm(self, psm: int) -> None:
self.used_psm.remove(psm)
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -1527,6 +1634,35 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@overload
def listen(
self, cb: Callable[[IncomingConnection.Basic], None]
) -> Callable[[IncomingConnection.Basic], None]:
...
@overload
def listen(
self, cb: Callable[[IncomingConnection.LeCreditBased], None]
) -> Callable[[IncomingConnection.LeCreditBased], None]:
...
def listen(self, cb: Any) -> Any:
if cb in self.listeners:
raise ValueError('listener already registered')
self.listeners.append(cb)
return cb
@overload
def unlisten(self, cb: Callable[[IncomingConnection.Basic], None]) -> None:
...
@overload
def unlisten(self, cb: Callable[[IncomingConnection.LeCreditBased], None]) -> None:
...
def unlisten(self, cb: Any) -> None:
self.listeners.remove(cb)
@deprecated("Please use create_classic_server")
def register_server(
self,
@@ -1534,7 +1670,7 @@ class ChannelManager:
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
handler=server, spec=ClassicChannelSpec(psm=None if psm == 0 else psm)
).psm
def create_classic_server(
@@ -1542,24 +1678,12 @@ class ChannelManager:
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: ClassicChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
@@ -1568,10 +1692,22 @@ class ChannelManager:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.used_psm.add(spec.psm)
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
def listener(incoming: IncomingConnection.Basic) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.Basic(server.on_connection, spec.mtu)
)
return self.servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = ClassicChannelServer(close, spec.psm, handler)
return server
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
@@ -1594,32 +1730,30 @@ class ChannelManager:
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.le_coc_servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: LeCreditBasedChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: SPSM already in use')
self.used_psm.add(spec.psm)
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
spec.psm,
handler,
max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
)
def listener(incoming: IncomingConnection.LeCreditBased) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.LeCreditBased(
server.on_connection, spec.mtu, spec.mps, spec.max_credits
)
)
return self.le_coc_servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = LeCreditBasedChannelServer(close, spec.psm, handler)
return server
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1719,15 +1853,62 @@ class ChannelManager:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(
self, connection: Connection, cid: int, request
self, connection: Connection, cid: int, request: L2CAP_Connection_Request
) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.Basic(
connection, request.psm, request.source_cid
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, pending.mtu
)
connection_channels[source_cid] = channel
# Notify
pending.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
@@ -1736,41 +1917,13 @@ class ChannelManager:
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.mtu
)
connection_channels[source_cid] = channel
# Notify
server.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_connection_response(
self, connection: Connection, cid: int, response
@@ -1971,108 +2124,135 @@ class ChannelManager:
)
def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request
self,
connection: Connection,
cid: int,
request: L2CAP_LE_Credit_Based_Connection_Request,
) -> None:
if request.le_psm in self.le_coc_servers:
server = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.LeCreditBased(
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
request.initial_credits,
)
# Notify
server.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
pending.mtu,
pending.mps,
request.initial_credits,
request.mtu,
request.mps,
pending.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=pending.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
# Notify
pending.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_le_credit_based_connection_response(
self, connection: Connection, _cid: int, response

View File

@@ -26,11 +26,13 @@ from .config import Config
from .device import PandoraDevice
from .host import HostService
from .security import SecurityService, SecurityStorageService
from .l2cap import L2CAPService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from typing import Callable, List, Optional
# public symbols
@@ -77,6 +79,7 @@ async def serve(
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
# call hooks if any.
for hook in _SERVICERS_HOOKS:

289
bumble/pandora/l2cap.py Normal file
View File

@@ -0,0 +1,289 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import grpc
import struct
from bumble import device
from bumble import l2cap
from bumble.pandora import config
from bumble.pandora import utils
from bumble.utils import EventWatcher
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora import l2cap_pb2
from pandora import l2cap_grpc_aio
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Union
@dataclasses.dataclass
class ChannelProxy:
channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel, None]
def __post_init__(self) -> None:
assert self.channel
self.rx: asyncio.Queue[bytes] = asyncio.Queue()
self._disconnection_result: asyncio.Future[None] = asyncio.Future()
self.channel.sink = self.rx.put_nowait
def on_close() -> None:
assert not self._disconnection_result.done()
self.channel = None
self._disconnection_result.set_result(None)
self.channel.on('close', on_close)
def send(self, data: bytes) -> None:
assert self.channel
if isinstance(self.channel, l2cap.ClassicChannel):
self.channel.send_pdu(data)
else:
self.channel.write(data)
async def disconnect(self) -> None:
assert self.channel
await self.channel.disconnect()
async def wait_disconnect(self) -> None:
await self._disconnection_result
assert not self.channel
@dataclasses.dataclass
class ChannelIndex:
connection_handle: int
cid: int
@classmethod
def from_token(cls, token: l2cap_pb2.Channel) -> 'ChannelIndex':
connection_handle, cid = struct.unpack('>HH', token.cookie.value)
return cls(connection_handle, cid)
def into_token(self) -> l2cap_pb2.Channel:
return l2cap_pb2.Channel(
cookie=any_pb2.Any(
value=struct.pack('>HH', self.connection_handle, self.cid)
)
)
def __hash__(self):
return hash(self.connection_handle | (self.cid << 12))
class L2CAPService(l2cap_grpc_aio.L2CAPServicer):
channels: Dict[ChannelIndex, ChannelProxy] = {}
pending: List[l2cap.IncomingConnection.Any] = []
accepts: List[asyncio.Queue[l2cap.IncomingConnection.Any]] = []
def __init__(self, dev: device.Device, config: config.Config) -> None:
self.device = dev
self.config = config
def on_connection(incoming: l2cap.IncomingConnection.Any) -> None:
self.pending.append(incoming)
for acceptor in self.accepts:
acceptor.put_nowait(incoming)
# Make sure our listener is called before the builtins ones.
self.device.l2cap_channel_manager.listeners.insert(0, on_connection)
def register(self, index: ChannelIndex, proxy: ChannelProxy) -> None:
self.channels[index] = proxy
def on_close(*_: Any) -> None:
# TODO: Fix Bumble L2CAP which emit `close` event twice.
if index in self.channels:
del self.channels[index]
# Listen for disconnection.
assert proxy.channel
proxy.channel.on('close', on_close)
async def listen(self) -> AsyncIterator[l2cap.IncomingConnection.Any]:
for incoming in self.pending:
if incoming.future.done():
self.pending.remove(incoming)
continue
yield incoming
queue: asyncio.Queue[l2cap.IncomingConnection.Any] = asyncio.Queue()
self.accepts.append(queue)
try:
while incoming := await queue.get():
yield incoming
finally:
self.accepts.remove(queue)
@utils.rpc
async def Connect(
self, request: l2cap_pb2.ConnectRequest, context: grpc.ServicerContext
) -> l2cap_pb2.ConnectResponse:
# Retrieve Bumble `Connection` from request.
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if connection is None:
raise RuntimeError(f'{connection_handle}: not connection for handle')
channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]
if request.type_variant() == 'basic':
assert request.basic
channel = await connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(
psm=request.basic.psm, mtu=request.basic.mtu
)
)
elif request.type_variant() == 'le_credit_based':
assert request.le_credit_based
channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
psm=request.le_credit_based.spsm,
max_credits=request.le_credit_based.initial_credit,
mtu=request.le_credit_based.mtu,
mps=request.le_credit_based.mps,
)
)
else:
raise NotImplementedError(f"{request.type_variant()}: unsupported type")
index = ChannelIndex(channel.connection.handle, channel.source_cid)
self.register(index, ChannelProxy(channel))
return l2cap_pb2.ConnectResponse(channel=index.into_token())
@utils.rpc
async def WaitConnection(
self, request: l2cap_pb2.WaitConnectionRequest, context: grpc.ServicerContext
) -> l2cap_pb2.WaitConnectionResponse:
iter = self.listen()
fut: asyncio.Future[
Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]
] = asyncio.Future()
# Filter by connection.
if request.connection:
handle = int.from_bytes(request.connection.cookie.value, 'big')
iter = (it async for it in iter if it.connection.handle == handle)
if request.type_variant() == 'basic':
assert request.basic
basic = l2cap.PendingConnection.Basic(
fut.set_result,
request.basic.mtu or l2cap.L2CAP_MIN_BR_EDR_MTU,
)
async for i in (
it
async for it in iter
if isinstance(it, l2cap.IncomingConnection.Basic)
):
if not i.future.done() and i.psm == request.basic.psm:
i.future.set_result(basic)
break
elif request.type_variant() == 'le_credit_based':
assert request.le_credit_based
le_credit_based = l2cap.PendingConnection.LeCreditBased(
fut.set_result,
request.le_credit_based.mtu
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
request.le_credit_based.mps
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
request.le_credit_based.initial_credit
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
)
async for j in (
it
async for it in iter
if isinstance(it, l2cap.IncomingConnection.LeCreditBased)
):
if not j.future.done() and j.psm == request.le_credit_based.spsm:
j.future.set_result(le_credit_based)
break
else:
raise NotImplementedError(f"{request.type_variant()}: unsupported type")
channel = await fut
index = ChannelIndex(channel.connection.handle, channel.source_cid)
self.register(index, ChannelProxy(channel))
return l2cap_pb2.WaitConnectionResponse(channel=index.into_token())
@utils.rpc
async def Disconnect(
self, request: l2cap_pb2.DisconnectRequest, context: grpc.ServicerContext
) -> l2cap_pb2.DisconnectResponse:
channel = self.channels[ChannelIndex.from_token(request.channel)]
await channel.disconnect()
return l2cap_pb2.DisconnectResponse(success=empty_pb2.Empty())
@utils.rpc
async def WaitDisconnection(
self, request: l2cap_pb2.WaitDisconnectionRequest, context: grpc.ServicerContext
) -> l2cap_pb2.WaitDisconnectionResponse:
channel = self.channels[ChannelIndex.from_token(request.channel)]
await channel.wait_disconnect()
return l2cap_pb2.WaitDisconnectionResponse(success=empty_pb2.Empty())
@utils.rpc
async def Receive(
self, request: l2cap_pb2.ReceiveRequest, context: grpc.ServicerContext
) -> AsyncGenerator[l2cap_pb2.ReceiveResponse, None]:
watcher = EventWatcher()
if request.source_variant() == 'channel':
assert request.channel
channel = self.channels[ChannelIndex.from_token(request.channel)]
rx = channel.rx
elif request.source_variant() == 'fixed_channel':
assert request.fixed_channel
rx = asyncio.Queue()
handle = request.fixed_channel.connection is not None and int.from_bytes(
request.fixed_channel.connection.cookie.value, 'big'
)
@watcher.on(self.device.host, 'l2cap_pdu')
def _(connection: device.Connection, cid: int, pdu: bytes) -> None:
assert request.fixed_channel
if cid == request.fixed_channel.cid and (
handle is None or handle == connection.handle
):
rx.put_nowait(pdu)
else:
raise NotImplementedError(f"{request.source_variant()}: unsupported type")
try:
while data := await rx.get():
yield l2cap_pb2.ReceiveResponse(data=data)
finally:
watcher.close()
@utils.rpc
async def Send(
self, request: l2cap_pb2.SendRequest, context: grpc.ServicerContext
) -> l2cap_pb2.SendResponse:
if request.sink_variant() == 'channel':
assert request.channel
channel = self.channels[ChannelIndex.from_token(request.channel)]
channel.send(request.data)
elif request.sink_variant() == 'fixed_channel':
assert request.fixed_channel
# Retrieve Bumble `Connection` from request.
connection_handle = int.from_bytes(
request.fixed_channel.connection.cookie.value, 'big'
)
connection = self.device.lookup_connection(connection_handle)
if connection is None:
raise RuntimeError(f'{connection_handle}: not connection for handle')
self.device.l2cap_channel_manager.send_pdu(
connection, request.fixed_channel.cid, request.data
)
else:
raise NotImplementedError(f"{request.sink_variant()}: unsupported type")
return l2cap_pb2.SendResponse(success=empty_pb2.Empty())

View File

@@ -33,7 +33,7 @@ include_package_data = True
install_requires =
aiohttp ~= 3.8; platform_system!='Emscripten'
appdirs >= 1.4; platform_system!='Emscripten'
bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
bt-test-interfaces >= 0.0.5; platform_system!='Emscripten'
click == 8.1.3; platform_system!='Emscripten'
cryptography == 39; platform_system!='Emscripten'
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the