diff --git a/bumble/pandora/__init__.py b/bumble/pandora/__init__.py index e02f54a..8fb4b6e 100644 --- a/bumble/pandora/__init__.py +++ b/bumble/pandora/__init__.py @@ -25,8 +25,10 @@ import grpc.aio from .config import Config from .device import PandoraDevice from .host import HostService +from .l2cap import L2CAPService from .security import SecurityService, SecurityStorageService from pandora.host_grpc_aio import add_HostServicer_to_server +from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server from pandora.security_grpc_aio import ( add_SecurityServicer_to_server, add_SecurityStorageServicer_to_server, @@ -77,6 +79,7 @@ async def serve( add_SecurityStorageServicer_to_server( SecurityStorageService(bumble.device, config), server ) + add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server) # call hooks if any. for hook in _SERVICERS_HOOKS: diff --git a/bumble/pandora/l2cap.py b/bumble/pandora/l2cap.py new file mode 100644 index 0000000..f88c434 --- /dev/null +++ b/bumble/pandora/l2cap.py @@ -0,0 +1,322 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +import asyncio +import grpc +import json +import logging +import threading + +from . import utils +from .config import Config +from bumble.device import Device +from bumble.l2cap import ( + ClassicChannel, + ClassicChannelServer, + ClassicChannelSpec, + LeCreditBasedChannel, + LeCreditBasedChannelServer, + LeCreditBasedChannelSpec, +) +from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error +from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error +from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error + COMMAND_NOT_UNDERSTOOD, + INVALID_CID_IN_REQUEST, + Channel, + ConnectRequest, + ConnectResponse, + CreditBasedChannelRequest, + DisconnectRequest, + DisconnectResponse, + ReceiveRequest, + ReceiveResponse, + SendRequest, + SendResponse, + WaitConnectionRequest, + WaitConnectionResponse, + WaitDisconnectionRequest, + WaitDisconnectionResponse, +) +from typing import Any, AsyncGenerator, Dict, Optional, Union + + +class L2CAPService(L2CAPServicer): + def __init__(self, device: Device, config: Config) -> None: + self.log = utils.BumbleServerLoggerAdapter( + logging.getLogger(), {'service_name': 'L2CAP', 'device': device} + ) + self.device = device + self.config = config + self.sdu_queue: asyncio.Queue = asyncio.Queue() + + @utils.rpc + async def WaitConnection( + self, request: WaitConnectionRequest, context: grpc.ServicerContext + ) -> WaitConnectionResponse: + self.log.debug('WaitConnection') + if not request.connection: + raise ValueError('A valid connection field must be set') + + # find connection on device based on connection cookie value + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + connection = self.device.lookup_connection(connection_handle) + + if not connection: + raise ValueError('The connection specified is invalid.') + + oneof = request.WhichOneof('type') + self.log.debug(f'WaitConnection channel request type: {oneof}.') + channel_type = getattr(request, oneof) + spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None + l2cap_server: Optional[ + Union[ClassicChannelServer, LeCreditBasedChannelServer] + ] = None + if isinstance(channel_type, CreditBasedChannelRequest): + spec = LeCreditBasedChannelSpec( + psm=channel_type.spsm, + max_credits=channel_type.initial_credit, + mtu=channel_type.mtu, + mps=channel_type.mps, + ) + if channel_type.spsm in self.device.l2cap_channel_manager.le_coc_servers: + l2cap_server = self.device.l2cap_channel_manager.le_coc_servers[ + channel_type.spsm + ] + else: + spec = ClassicChannelSpec( + psm=channel_type.psm, + mtu=channel_type.mtu, + ) + if channel_type.psm in self.device.l2cap_channel_manager.servers: + l2cap_server = self.device.l2cap_channel_manager.servers[ + channel_type.psm + ] + + self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}') + channel_future: asyncio.Future[Union[ClassicChannel, LeCreditBasedChannel]] = ( + asyncio.get_running_loop().create_future() + ) + + def on_l2cap_channel( + l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel] + ): + try: + channel_future.set_result(l2cap_channel) + self.log.debug( + f'Channel future set successfully with channel= {l2cap_channel}' + ) + except Exception as e: + self.log.error(f'Failed to set channel future: {e}') + + if l2cap_server is None: + l2cap_server = self.device.create_l2cap_server( + spec=spec, handler=on_l2cap_channel + ) + else: + l2cap_server.on('connection', on_l2cap_channel) + + try: + self.log.debug('Waiting for a channel connection.') + l2cap_channel = await channel_future + channel = self.channel_to_proto(l2cap_channel) + return WaitConnectionResponse(channel=channel) + except Exception as e: + self.log.warning(f'Exception: {e}') + return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD) + + @utils.rpc + async def WaitDisconnection( + self, request: WaitDisconnectionRequest, context: grpc.ServicerContext + ) -> WaitDisconnectionResponse: + try: + self.log.debug('WaitDisconnection') + l2cap_channel = self.get_l2cap_channel(request.channel) + if l2cap_channel is None: + self.log.warn('WaitDisconnection: Unable to find the channel') + return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST) + + self.log.debug('WaitDisconnection: Sending a disconnection request') + closed_event: asyncio.Event = asyncio.Event() + + def on_close(): + self.log.info('Received a close event') + closed_event.set() + + l2cap_channel.on('close', on_close) + await closed_event.wait() + return WaitDisconnectionResponse(success=empty_pb2.Empty()) + except Exception as e: + self.log.exception(f'WaitDisonnection failed: {e}') + return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD) + + @utils.rpc + async def Receive( + self, request: ReceiveRequest, context: grpc.ServicerContext + ) -> AsyncGenerator[ReceiveResponse, None]: + self.log.debug('Receive') + oneof = request.WhichOneof('source') + self.log.debug(f'Source: {oneof}.') + channel = getattr(request, oneof) + + if not isinstance(channel, Channel): + raise NotImplementedError(f'TODO: {type(channel)} not currently supported.') + + def on_channel_sdu(sdu): + async def handle_sdu(): + await self.sdu_queue.put(sdu) + + asyncio.create_task(handle_sdu()) + + l2cap_channel = self.get_l2cap_channel(channel) + if l2cap_channel is None: + raise ValueError('The channel in the request is not valid.') + + l2cap_channel.sink = on_channel_sdu + while sdu := await self.sdu_queue.get(): + # Retrieve the next SDU from the queue + self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}') + response = ReceiveResponse(data=sdu) + yield response + + @utils.rpc + async def Connect( + self, request: ConnectRequest, context: grpc.ServicerContext + ) -> ConnectResponse: + self.log.debug('Connect') + + if not request.connection: + raise ValueError('A valid connection field must be set') + + # find connection on device based on connection cookie value + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + connection = self.device.lookup_connection(connection_handle) + + if not connection: + raise ValueError('The connection specified is invalid.') + + oneof = request.WhichOneof('type') + self.log.debug(f'Channel request type: {oneof}.') + channel_type = getattr(request, oneof) + spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None + if isinstance(channel_type, CreditBasedChannelRequest): + spec = LeCreditBasedChannelSpec( + psm=channel_type.spsm, + max_credits=channel_type.initial_credit, + mtu=channel_type.mtu, + mps=channel_type.mps, + ) + else: + spec = ClassicChannelSpec( + psm=channel_type.psm, + mtu=channel_type.mtu, + ) + + try: + self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}') + l2cap_channel = await connection.create_l2cap_channel(spec=spec) + self.log.info(f'L2CAP channel: {l2cap_channel}') + except Exception as e: + l2cap_channel = None + self.log.exception(f'Connection failed: {e}') + + if not l2cap_channel: + return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD) + + channel = self.channel_to_proto(l2cap_channel) + return ConnectResponse(channel=channel) + + @utils.rpc + async def Disconnect( + self, request: DisconnectRequest, context: grpc.ServicerContext + ) -> DisconnectResponse: + try: + self.log.debug('Disconnect') + l2cap_channel = self.get_l2cap_channel(request.channel) + if not l2cap_channel: + self.log.warn('Disconnect: Unable to find the channel') + return DisconnectResponse(error=INVALID_CID_IN_REQUEST) + + await l2cap_channel.disconnect() + return DisconnectResponse(success=empty_pb2.Empty()) + except Exception as e: + self.log.exception(f'Disonnect failed: {e}') + return DisconnectResponse(error=COMMAND_NOT_UNDERSTOOD) + + @utils.rpc + async def Send( + self, request: SendRequest, context: grpc.ServicerContext + ) -> SendResponse: + self.log.debug('Send') + try: + oneof = request.WhichOneof('sink') + self.log.debug(f'Sink: {oneof}.') + channel = getattr(request, oneof) + + if not isinstance(channel, Channel): + raise NotImplementedError( + f'TODO: {type(channel)} not currently supported.' + ) + l2cap_channel = self.get_l2cap_channel(channel) + if not l2cap_channel: + return SendResponse(error=COMMAND_NOT_UNDERSTOOD) + if isinstance(l2cap_channel, ClassicChannel): + l2cap_channel.send_pdu(request.data) + else: + l2cap_channel.write(request.data) + return SendResponse(success=empty_pb2.Empty()) + except Exception as e: + self.log.exception(f'Disonnect failed: {e}') + return SendResponse(error=COMMAND_NOT_UNDERSTOOD) + + def get_l2cap_channel( + self, channel: Channel + ) -> Optional[Union[ClassicChannel, LeCreditBasedChannel]]: + parameters = self.get_channel_parameters(channel) + connection_handle = parameters.get('connection_handle', 0) + destination_cid = parameters.get('destination_cid', 0) + is_classic = parameters.get('is_classic', False) + self.log.debug( + f'get_l2cap_channel: Connection handle:{connection_handle}, cid:{destination_cid}' + ) + l2cap_channel: Optional[Union[ClassicChannel, LeCreditBasedChannel]] = None + if is_classic: + l2cap_channel = self.device.l2cap_channel_manager.find_channel( + connection_handle, destination_cid + ) + else: + l2cap_channel = self.device.l2cap_channel_manager.find_le_coc_channel( + connection_handle, destination_cid + ) + return l2cap_channel + + def channel_to_proto( + self, l2cap_channel: Union[ClassicChannel, LeCreditBasedChannel] + ) -> Channel: + parameters = { + "source_cid": l2cap_channel.source_cid, + "destination_cid": l2cap_channel.destination_cid, + "connection_handle": l2cap_channel.connection.handle, + "is_classic": True if isinstance(l2cap_channel, ClassicChannel) else False, + } + self.log.info(f'Channel parameters: {parameters}') + cookie = any_pb2.Any() + cookie.value = json.dumps(parameters).encode() + return Channel(cookie=cookie) + + def get_channel_parameters(self, channel: Channel) -> Dict['str', Any]: + cookie_value = channel.cookie.value.decode() + parameters = json.loads(cookie_value) + self.log.info(f'Channel parameters: {parameters}') + return parameters