Pandora: refactor l2cap service

* Craft the PandoraChannel from the connection_handle and the source_cid
* Fix race on waitDisconnection
* Add ChannelContext to enable mutliple channels on the service
This commit is contained in:
Charlie Boutier
2024-08-30 21:14:31 +00:00
parent 4394a36332
commit 1256170985

View File

@@ -16,10 +16,12 @@ import asyncio
import grpc import grpc
import json import json
import logging import logging
import threading
from asyncio import Queue as AsyncQueue, Future
from . import utils from . import utils
from .config import Config from .config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device from bumble.device import Device
from bumble.l2cap import ( from bumble.l2cap import (
ClassicChannel, ClassicChannel,
@@ -34,7 +36,7 @@ from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD, COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST, INVALID_CID_IN_REQUEST,
Channel, Channel as PandoraChannel,
ConnectRequest, ConnectRequest,
ConnectResponse, ConnectResponse,
CreditBasedChannelRequest, CreditBasedChannelRequest,
@@ -49,7 +51,16 @@ from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
WaitDisconnectionRequest, WaitDisconnectionRequest,
WaitDisconnectionResponse, WaitDisconnectionResponse,
) )
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
class ChannelContext:
close_future: Future
sdu_queue: AsyncQueue
class L2CAPService(L2CAPServicer): class L2CAPService(L2CAPServicer):
@@ -59,7 +70,22 @@ class L2CAPService(L2CAPServicer):
) )
self.device = device self.device = device
self.config = config self.config = config
self.sdu_queue: asyncio.Queue = asyncio.Queue() self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future()
sdu_queue: AsyncQueue = AsyncQueue()
def on_channel_sdu(sdu):
sdu_queue.put_nowait(sdu)
def on_close():
close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close)
return ChannelContext(close_future, sdu_queue)
@utils.rpc @utils.rpc
async def WaitConnection( async def WaitConnection(
@@ -105,18 +131,18 @@ class L2CAPService(L2CAPServicer):
] ]
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}') self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
channel_future: asyncio.Future[Union[ClassicChannel, LeCreditBasedChannel]] = ( channel_future: Future[PandoraChannel] = (
asyncio.get_running_loop().create_future() asyncio.get_running_loop().create_future()
) )
def on_l2cap_channel( def on_l2cap_channel(l2cap_channel: L2capChannel):
l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel]
):
try: try:
channel_future.set_result(l2cap_channel) channel_context = self.register_event(l2cap_channel)
self.log.debug( pandora_channel: PandoraChannel = self.craft_pandora_channel(
f'Channel future set successfully with channel= {l2cap_channel}' connection_handle, l2cap_channel
) )
self.channels[pandora_channel.cookie.value] = channel_context
channel_future.set_result(pandora_channel)
except Exception as e: except Exception as e:
self.log.error(f'Failed to set channel future: {e}') self.log.error(f'Failed to set channel future: {e}')
@@ -129,11 +155,12 @@ class L2CAPService(L2CAPServicer):
try: try:
self.log.debug('Waiting for a channel connection.') self.log.debug('Waiting for a channel connection.')
l2cap_channel = await channel_future pandora_channel: PandoraChannel = await channel_future
channel = self.channel_to_proto(l2cap_channel)
return WaitConnectionResponse(channel=channel) return WaitConnectionResponse(channel=pandora_channel)
except Exception as e: except Exception as e:
self.log.warning(f'Exception: {e}') self.log.warning(f'Exception: {e}')
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD) return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc @utils.rpc
@@ -142,21 +169,13 @@ class L2CAPService(L2CAPServicer):
) -> WaitDisconnectionResponse: ) -> WaitDisconnectionResponse:
try: try:
self.log.debug('WaitDisconnection') self.log.debug('WaitDisconnection')
l2cap_channel = self.get_l2cap_channel(request.channel)
if l2cap_channel is None:
self.log.warn('WaitDisconnection: Unable to find the channel')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
self.log.debug('WaitDisconnection: Sending a disconnection request') await self.lookup_context(request.channel).close_future
closed_event: asyncio.Event = asyncio.Event() self.log.debug("return WaitDisconnectionResponse")
def on_close():
self.log.info('Received a close event')
closed_event.set()
l2cap_channel.on('close', on_close)
await closed_event.wait()
return WaitDisconnectionResponse(success=empty_pb2.Empty()) return WaitDisconnectionResponse(success=empty_pb2.Empty())
except KeyError as e:
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
except Exception as e: except Exception as e:
self.log.exception(f'WaitDisonnection failed: {e}') self.log.exception(f'WaitDisonnection failed: {e}')
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD) return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@@ -168,24 +187,11 @@ class L2CAPService(L2CAPServicer):
self.log.debug('Receive') self.log.debug('Receive')
oneof = request.WhichOneof('source') oneof = request.WhichOneof('source')
self.log.debug(f'Source: {oneof}.') self.log.debug(f'Source: {oneof}.')
channel = getattr(request, oneof) pandora_channel = getattr(request, oneof)
if not isinstance(channel, Channel): sdu_queue = self.lookup_context(pandora_channel).sdu_queue
raise NotImplementedError(f'TODO: {type(channel)} not currently supported.')
def on_channel_sdu(sdu): while sdu := await sdu_queue.get():
async def handle_sdu():
await self.sdu_queue.put(sdu)
asyncio.create_task(handle_sdu())
l2cap_channel = self.get_l2cap_channel(channel)
if l2cap_channel is None:
raise ValueError('The channel in the request is not valid.')
l2cap_channel.sink = on_channel_sdu
while sdu := await self.sdu_queue.get():
# Retrieve the next SDU from the queue
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}') self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
response = ReceiveResponse(data=sdu) response = ReceiveResponse(data=sdu)
yield response yield response
@@ -226,26 +232,30 @@ class L2CAPService(L2CAPServicer):
try: try:
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}') self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
l2cap_channel = await connection.create_l2cap_channel(spec=spec) l2cap_channel = await connection.create_l2cap_channel(spec=spec)
self.log.info(f'L2CAP channel: {l2cap_channel}') channel_context = self.register_event(l2cap_channel)
except Exception as e: pandora_channel = self.craft_pandora_channel(
l2cap_channel = None connection_handle, l2cap_channel
self.log.exception(f'Connection failed: {e}') )
self.channels[pandora_channel.cookie.value] = channel_context
if not l2cap_channel: return ConnectResponse(channel=pandora_channel)
except OutOfResourcesError as e:
self.log.error(e)
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
except InvalidArgumentError as e:
self.log.error(e)
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD) return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
channel = self.channel_to_proto(l2cap_channel)
return ConnectResponse(channel=channel)
@utils.rpc @utils.rpc
async def Disconnect( async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext self, request: DisconnectRequest, context: grpc.ServicerContext
) -> DisconnectResponse: ) -> DisconnectResponse:
try: try:
self.log.debug('Disconnect') self.log.debug('Disconnect')
l2cap_channel = self.get_l2cap_channel(request.channel) l2cap_channel = self.lookup_channel(request.channel)
if not l2cap_channel: if not l2cap_channel:
self.log.warn('Disconnect: Unable to find the channel') self.log.warning('Disconnect: Unable to find the channel')
return DisconnectResponse(error=INVALID_CID_IN_REQUEST) return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
await l2cap_channel.disconnect() await l2cap_channel.disconnect()
@@ -262,13 +272,9 @@ class L2CAPService(L2CAPServicer):
try: try:
oneof = request.WhichOneof('sink') oneof = request.WhichOneof('sink')
self.log.debug(f'Sink: {oneof}.') self.log.debug(f'Sink: {oneof}.')
channel = getattr(request, oneof) pandora_channel = getattr(request, oneof)
if not isinstance(channel, Channel): l2cap_channel = self.lookup_channel(pandora_channel)
raise NotImplementedError(
f'TODO: {type(channel)} not currently supported.'
)
l2cap_channel = self.get_l2cap_channel(channel)
if not l2cap_channel: if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD) return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel): if isinstance(l2cap_channel, ClassicChannel):
@@ -280,43 +286,25 @@ class L2CAPService(L2CAPServicer):
self.log.exception(f'Disonnect failed: {e}') self.log.exception(f'Disonnect failed: {e}')
return SendResponse(error=COMMAND_NOT_UNDERSTOOD) return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
def get_l2cap_channel( def craft_pandora_channel(
self, channel: Channel self,
) -> Optional[Union[ClassicChannel, LeCreditBasedChannel]]: connection_handle: int,
parameters = self.get_channel_parameters(channel) l2cap_channel: L2capChannel,
connection_handle = parameters.get('connection_handle', 0) ) -> PandoraChannel:
destination_cid = parameters.get('destination_cid', 0)
is_classic = parameters.get('is_classic', False)
self.log.debug(
f'get_l2cap_channel: Connection handle:{connection_handle}, cid:{destination_cid}'
)
l2cap_channel: Optional[Union[ClassicChannel, LeCreditBasedChannel]] = None
if is_classic:
l2cap_channel = self.device.l2cap_channel_manager.find_channel(
connection_handle, destination_cid
)
else:
l2cap_channel = self.device.l2cap_channel_manager.find_le_coc_channel(
connection_handle, destination_cid
)
return l2cap_channel
def channel_to_proto(
self, l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel]
) -> Channel:
parameters = { parameters = {
"connection_handle": connection_handle,
"source_cid": l2cap_channel.source_cid, "source_cid": l2cap_channel.source_cid,
"destination_cid": l2cap_channel.destination_cid,
"connection_handle": l2cap_channel.connection.handle,
"is_classic": True if isinstance(l2cap_channel, ClassicChannel) else False,
} }
self.log.info(f'Channel parameters: {parameters}')
cookie = any_pb2.Any() cookie = any_pb2.Any()
cookie.value = json.dumps(parameters).encode() cookie.value = json.dumps(parameters).encode()
return Channel(cookie=cookie) return PandoraChannel(cookie=cookie)
def get_channel_parameters(self, channel: Channel) -> Dict['str', Any]: def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
cookie_value = channel.cookie.value.decode() (connection_handle, source_cid) = json.loads(
parameters = json.loads(cookie_value) pandora_channel.cookie.value
self.log.info(f'Channel parameters: {parameters}') ).values()
return parameters
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
return self.channels[pandora_channel.cookie.value]