Pandora: refactor l2cap service

* Craft the PandoraChannel from the connection_handle and the source_cid
* Fix race on waitDisconnection
* Add ChannelContext to enable mutliple channels on the service
This commit is contained in:
Charlie Boutier
2024-08-30 21:14:31 +00:00
parent 4394a36332
commit 1256170985

View File

@@ -16,10 +16,12 @@ import asyncio
import grpc
import json
import logging
import threading
from asyncio import Queue as AsyncQueue, Future
from . import utils
from .config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
@@ -34,7 +36,7 @@ from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
Channel,
Channel as PandoraChannel,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
@@ -49,7 +51,16 @@ from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from typing import Any, AsyncGenerator, Dict, Optional, Union
from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
class ChannelContext:
close_future: Future
sdu_queue: AsyncQueue
class L2CAPService(L2CAPServicer):
@@ -59,7 +70,22 @@ class L2CAPService(L2CAPServicer):
)
self.device = device
self.config = config
self.sdu_queue: asyncio.Queue = asyncio.Queue()
self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future()
sdu_queue: AsyncQueue = AsyncQueue()
def on_channel_sdu(sdu):
sdu_queue.put_nowait(sdu)
def on_close():
close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close)
return ChannelContext(close_future, sdu_queue)
@utils.rpc
async def WaitConnection(
@@ -105,18 +131,18 @@ class L2CAPService(L2CAPServicer):
]
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
channel_future: asyncio.Future[Union[ClassicChannel, LeCreditBasedChannel]] = (
channel_future: Future[PandoraChannel] = (
asyncio.get_running_loop().create_future()
)
def on_l2cap_channel(
l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel]
):
def on_l2cap_channel(l2cap_channel: L2capChannel):
try:
channel_future.set_result(l2cap_channel)
self.log.debug(
f'Channel future set successfully with channel= {l2cap_channel}'
channel_context = self.register_event(l2cap_channel)
pandora_channel: PandoraChannel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
channel_future.set_result(pandora_channel)
except Exception as e:
self.log.error(f'Failed to set channel future: {e}')
@@ -129,11 +155,12 @@ class L2CAPService(L2CAPServicer):
try:
self.log.debug('Waiting for a channel connection.')
l2cap_channel = await channel_future
channel = self.channel_to_proto(l2cap_channel)
return WaitConnectionResponse(channel=channel)
pandora_channel: PandoraChannel = await channel_future
return WaitConnectionResponse(channel=pandora_channel)
except Exception as e:
self.log.warning(f'Exception: {e}')
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
@@ -142,21 +169,13 @@ class L2CAPService(L2CAPServicer):
) -> WaitDisconnectionResponse:
try:
self.log.debug('WaitDisconnection')
l2cap_channel = self.get_l2cap_channel(request.channel)
if l2cap_channel is None:
self.log.warn('WaitDisconnection: Unable to find the channel')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
self.log.debug('WaitDisconnection: Sending a disconnection request')
closed_event: asyncio.Event = asyncio.Event()
def on_close():
self.log.info('Received a close event')
closed_event.set()
l2cap_channel.on('close', on_close)
await closed_event.wait()
await self.lookup_context(request.channel).close_future
self.log.debug("return WaitDisconnectionResponse")
return WaitDisconnectionResponse(success=empty_pb2.Empty())
except KeyError as e:
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
except Exception as e:
self.log.exception(f'WaitDisonnection failed: {e}')
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@@ -168,24 +187,11 @@ class L2CAPService(L2CAPServicer):
self.log.debug('Receive')
oneof = request.WhichOneof('source')
self.log.debug(f'Source: {oneof}.')
channel = getattr(request, oneof)
pandora_channel = getattr(request, oneof)
if not isinstance(channel, Channel):
raise NotImplementedError(f'TODO: {type(channel)} not currently supported.')
sdu_queue = self.lookup_context(pandora_channel).sdu_queue
def on_channel_sdu(sdu):
async def handle_sdu():
await self.sdu_queue.put(sdu)
asyncio.create_task(handle_sdu())
l2cap_channel = self.get_l2cap_channel(channel)
if l2cap_channel is None:
raise ValueError('The channel in the request is not valid.')
l2cap_channel.sink = on_channel_sdu
while sdu := await self.sdu_queue.get():
# Retrieve the next SDU from the queue
while sdu := await sdu_queue.get():
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
response = ReceiveResponse(data=sdu)
yield response
@@ -226,26 +232,30 @@ class L2CAPService(L2CAPServicer):
try:
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
l2cap_channel = await connection.create_l2cap_channel(spec=spec)
self.log.info(f'L2CAP channel: {l2cap_channel}')
except Exception as e:
l2cap_channel = None
self.log.exception(f'Connection failed: {e}')
channel_context = self.register_event(l2cap_channel)
pandora_channel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
if not l2cap_channel:
return ConnectResponse(channel=pandora_channel)
except OutOfResourcesError as e:
self.log.error(e)
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
except InvalidArgumentError as e:
self.log.error(e)
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
channel = self.channel_to_proto(l2cap_channel)
return ConnectResponse(channel=channel)
@utils.rpc
async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> DisconnectResponse:
try:
self.log.debug('Disconnect')
l2cap_channel = self.get_l2cap_channel(request.channel)
l2cap_channel = self.lookup_channel(request.channel)
if not l2cap_channel:
self.log.warn('Disconnect: Unable to find the channel')
self.log.warning('Disconnect: Unable to find the channel')
return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
await l2cap_channel.disconnect()
@@ -262,13 +272,9 @@ class L2CAPService(L2CAPServicer):
try:
oneof = request.WhichOneof('sink')
self.log.debug(f'Sink: {oneof}.')
channel = getattr(request, oneof)
pandora_channel = getattr(request, oneof)
if not isinstance(channel, Channel):
raise NotImplementedError(
f'TODO: {type(channel)} not currently supported.'
)
l2cap_channel = self.get_l2cap_channel(channel)
l2cap_channel = self.lookup_channel(pandora_channel)
if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel):
@@ -280,43 +286,25 @@ class L2CAPService(L2CAPServicer):
self.log.exception(f'Disonnect failed: {e}')
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
def get_l2cap_channel(
self, channel: Channel
) -> Optional[Union[ClassicChannel, LeCreditBasedChannel]]:
parameters = self.get_channel_parameters(channel)
connection_handle = parameters.get('connection_handle', 0)
destination_cid = parameters.get('destination_cid', 0)
is_classic = parameters.get('is_classic', False)
self.log.debug(
f'get_l2cap_channel: Connection handle:{connection_handle}, cid:{destination_cid}'
)
l2cap_channel: Optional[Union[ClassicChannel, LeCreditBasedChannel]] = None
if is_classic:
l2cap_channel = self.device.l2cap_channel_manager.find_channel(
connection_handle, destination_cid
)
else:
l2cap_channel = self.device.l2cap_channel_manager.find_le_coc_channel(
connection_handle, destination_cid
)
return l2cap_channel
def channel_to_proto(
self, l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel]
) -> Channel:
def craft_pandora_channel(
self,
connection_handle: int,
l2cap_channel: L2capChannel,
) -> PandoraChannel:
parameters = {
"connection_handle": connection_handle,
"source_cid": l2cap_channel.source_cid,
"destination_cid": l2cap_channel.destination_cid,
"connection_handle": l2cap_channel.connection.handle,
"is_classic": True if isinstance(l2cap_channel, ClassicChannel) else False,
}
self.log.info(f'Channel parameters: {parameters}')
cookie = any_pb2.Any()
cookie.value = json.dumps(parameters).encode()
return Channel(cookie=cookie)
return PandoraChannel(cookie=cookie)
def get_channel_parameters(self, channel: Channel) -> Dict['str', Any]:
cookie_value = channel.cookie.value.decode()
parameters = json.loads(cookie_value)
self.log.info(f'Channel parameters: {parameters}')
return parameters
def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
(connection_handle, source_cid) = json.loads(
pandora_channel.cookie.value
).values()
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
return self.channels[pandora_channel.cookie.value]