From 12561709852231b86b3aac4a1bf8a574c415a6ab Mon Sep 17 00:00:00 2001 From: Charlie Boutier Date: Fri, 30 Aug 2024 21:14:31 +0000 Subject: [PATCH] Pandora: refactor l2cap service * Craft the PandoraChannel from the connection_handle and the source_cid * Fix race on waitDisconnection * Add ChannelContext to enable mutliple channels on the service --- bumble/pandora/l2cap.py | 174 +++++++++++++++++++--------------------- 1 file changed, 81 insertions(+), 93 deletions(-) diff --git a/bumble/pandora/l2cap.py b/bumble/pandora/l2cap.py index f88c4341..488478c6 100644 --- a/bumble/pandora/l2cap.py +++ b/bumble/pandora/l2cap.py @@ -16,10 +16,12 @@ import asyncio import grpc import json import logging -import threading + +from asyncio import Queue as AsyncQueue, Future from . import utils from .config import Config +from bumble.core import OutOfResourcesError, InvalidArgumentError from bumble.device import Device from bumble.l2cap import ( ClassicChannel, @@ -34,7 +36,7 @@ from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error COMMAND_NOT_UNDERSTOOD, INVALID_CID_IN_REQUEST, - Channel, + Channel as PandoraChannel, ConnectRequest, ConnectResponse, CreditBasedChannelRequest, @@ -49,7 +51,16 @@ from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error WaitDisconnectionRequest, WaitDisconnectionResponse, ) -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import AsyncGenerator, Dict, Optional, Union +from dataclasses import dataclass + +L2capChannel = Union[ClassicChannel, LeCreditBasedChannel] + + +@dataclass +class ChannelContext: + close_future: Future + sdu_queue: AsyncQueue class L2CAPService(L2CAPServicer): @@ -59,7 +70,22 @@ class L2CAPService(L2CAPServicer): ) self.device = device self.config = config - self.sdu_queue: asyncio.Queue = asyncio.Queue() + self.channels: Dict[bytes, ChannelContext] = {} + + def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext: + close_future = asyncio.get_running_loop().create_future() + sdu_queue: AsyncQueue = AsyncQueue() + + def on_channel_sdu(sdu): + sdu_queue.put_nowait(sdu) + + def on_close(): + close_future.set_result(None) + + l2cap_channel.sink = on_channel_sdu + l2cap_channel.on('close', on_close) + + return ChannelContext(close_future, sdu_queue) @utils.rpc async def WaitConnection( @@ -105,18 +131,18 @@ class L2CAPService(L2CAPServicer): ] self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}') - channel_future: asyncio.Future[Union[ClassicChannel, LeCreditBasedChannel]] = ( + channel_future: Future[PandoraChannel] = ( asyncio.get_running_loop().create_future() ) - def on_l2cap_channel( - l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel] - ): + def on_l2cap_channel(l2cap_channel: L2capChannel): try: - channel_future.set_result(l2cap_channel) - self.log.debug( - f'Channel future set successfully with channel= {l2cap_channel}' + channel_context = self.register_event(l2cap_channel) + pandora_channel: PandoraChannel = self.craft_pandora_channel( + connection_handle, l2cap_channel ) + self.channels[pandora_channel.cookie.value] = channel_context + channel_future.set_result(pandora_channel) except Exception as e: self.log.error(f'Failed to set channel future: {e}') @@ -129,11 +155,12 @@ class L2CAPService(L2CAPServicer): try: self.log.debug('Waiting for a channel connection.') - l2cap_channel = await channel_future - channel = self.channel_to_proto(l2cap_channel) - return WaitConnectionResponse(channel=channel) + pandora_channel: PandoraChannel = await channel_future + + return WaitConnectionResponse(channel=pandora_channel) except Exception as e: self.log.warning(f'Exception: {e}') + return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD) @utils.rpc @@ -142,21 +169,13 @@ class L2CAPService(L2CAPServicer): ) -> WaitDisconnectionResponse: try: self.log.debug('WaitDisconnection') - l2cap_channel = self.get_l2cap_channel(request.channel) - if l2cap_channel is None: - self.log.warn('WaitDisconnection: Unable to find the channel') - return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST) - self.log.debug('WaitDisconnection: Sending a disconnection request') - closed_event: asyncio.Event = asyncio.Event() - - def on_close(): - self.log.info('Received a close event') - closed_event.set() - - l2cap_channel.on('close', on_close) - await closed_event.wait() + await self.lookup_context(request.channel).close_future + self.log.debug("return WaitDisconnectionResponse") return WaitDisconnectionResponse(success=empty_pb2.Empty()) + except KeyError as e: + self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}') + return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST) except Exception as e: self.log.exception(f'WaitDisonnection failed: {e}') return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD) @@ -168,24 +187,11 @@ class L2CAPService(L2CAPServicer): self.log.debug('Receive') oneof = request.WhichOneof('source') self.log.debug(f'Source: {oneof}.') - channel = getattr(request, oneof) + pandora_channel = getattr(request, oneof) - if not isinstance(channel, Channel): - raise NotImplementedError(f'TODO: {type(channel)} not currently supported.') + sdu_queue = self.lookup_context(pandora_channel).sdu_queue - def on_channel_sdu(sdu): - async def handle_sdu(): - await self.sdu_queue.put(sdu) - - asyncio.create_task(handle_sdu()) - - l2cap_channel = self.get_l2cap_channel(channel) - if l2cap_channel is None: - raise ValueError('The channel in the request is not valid.') - - l2cap_channel.sink = on_channel_sdu - while sdu := await self.sdu_queue.get(): - # Retrieve the next SDU from the queue + while sdu := await sdu_queue.get(): self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}') response = ReceiveResponse(data=sdu) yield response @@ -226,26 +232,30 @@ class L2CAPService(L2CAPServicer): try: self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}') l2cap_channel = await connection.create_l2cap_channel(spec=spec) - self.log.info(f'L2CAP channel: {l2cap_channel}') - except Exception as e: - l2cap_channel = None - self.log.exception(f'Connection failed: {e}') + channel_context = self.register_event(l2cap_channel) + pandora_channel = self.craft_pandora_channel( + connection_handle, l2cap_channel + ) + self.channels[pandora_channel.cookie.value] = channel_context - if not l2cap_channel: + return ConnectResponse(channel=pandora_channel) + + except OutOfResourcesError as e: + self.log.error(e) + return ConnectResponse(error=INVALID_CID_IN_REQUEST) + except InvalidArgumentError as e: + self.log.error(e) return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD) - channel = self.channel_to_proto(l2cap_channel) - return ConnectResponse(channel=channel) - @utils.rpc async def Disconnect( self, request: DisconnectRequest, context: grpc.ServicerContext ) -> DisconnectResponse: try: self.log.debug('Disconnect') - l2cap_channel = self.get_l2cap_channel(request.channel) + l2cap_channel = self.lookup_channel(request.channel) if not l2cap_channel: - self.log.warn('Disconnect: Unable to find the channel') + self.log.warning('Disconnect: Unable to find the channel') return DisconnectResponse(error=INVALID_CID_IN_REQUEST) await l2cap_channel.disconnect() @@ -262,13 +272,9 @@ class L2CAPService(L2CAPServicer): try: oneof = request.WhichOneof('sink') self.log.debug(f'Sink: {oneof}.') - channel = getattr(request, oneof) + pandora_channel = getattr(request, oneof) - if not isinstance(channel, Channel): - raise NotImplementedError( - f'TODO: {type(channel)} not currently supported.' - ) - l2cap_channel = self.get_l2cap_channel(channel) + l2cap_channel = self.lookup_channel(pandora_channel) if not l2cap_channel: return SendResponse(error=COMMAND_NOT_UNDERSTOOD) if isinstance(l2cap_channel, ClassicChannel): @@ -280,43 +286,25 @@ class L2CAPService(L2CAPServicer): self.log.exception(f'Disonnect failed: {e}') return SendResponse(error=COMMAND_NOT_UNDERSTOOD) - def get_l2cap_channel( - self, channel: Channel - ) -> Optional[Union[ClassicChannel, LeCreditBasedChannel]]: - parameters = self.get_channel_parameters(channel) - connection_handle = parameters.get('connection_handle', 0) - destination_cid = parameters.get('destination_cid', 0) - is_classic = parameters.get('is_classic', False) - self.log.debug( - f'get_l2cap_channel: Connection handle:{connection_handle}, cid:{destination_cid}' - ) - l2cap_channel: Optional[Union[ClassicChannel, LeCreditBasedChannel]] = None - if is_classic: - l2cap_channel = self.device.l2cap_channel_manager.find_channel( - connection_handle, destination_cid - ) - else: - l2cap_channel = self.device.l2cap_channel_manager.find_le_coc_channel( - connection_handle, destination_cid - ) - return l2cap_channel - - def channel_to_proto( - self, l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel] - ) -> Channel: + def craft_pandora_channel( + self, + connection_handle: int, + l2cap_channel: L2capChannel, + ) -> PandoraChannel: parameters = { + "connection_handle": connection_handle, "source_cid": l2cap_channel.source_cid, - "destination_cid": l2cap_channel.destination_cid, - "connection_handle": l2cap_channel.connection.handle, - "is_classic": True if isinstance(l2cap_channel, ClassicChannel) else False, } - self.log.info(f'Channel parameters: {parameters}') cookie = any_pb2.Any() cookie.value = json.dumps(parameters).encode() - return Channel(cookie=cookie) + return PandoraChannel(cookie=cookie) - def get_channel_parameters(self, channel: Channel) -> Dict['str', Any]: - cookie_value = channel.cookie.value.decode() - parameters = json.loads(cookie_value) - self.log.info(f'Channel parameters: {parameters}') - return parameters + def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel: + (connection_handle, source_cid) = json.loads( + pandora_channel.cookie.value + ).values() + + return self.device.l2cap_channel_manager.channels[connection_handle][source_cid] + + def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext: + return self.channels[pandora_channel.cookie.value]