diff --git a/apps/pandora_server.py b/apps/pandora_server.py new file mode 100644 index 0000000..5f92309 --- /dev/null +++ b/apps/pandora_server.py @@ -0,0 +1,30 @@ +import asyncio +import click +import logging + +from bumble.pandora import PandoraDevice, serve + +BUMBLE_SERVER_GRPC_PORT = 7999 +ROOTCANAL_PORT_CUTTLEFISH = 7300 + + +@click.command() +@click.option('--grpc-port', help='gRPC port to serve', default=BUMBLE_SERVER_GRPC_PORT) +@click.option( + '--rootcanal-port', help='Rootcanal TCP port', default=ROOTCANAL_PORT_CUTTLEFISH +) +@click.option( + '--transport', + help='HCI transport', + default=f'tcp-client:127.0.0.1:', +) +def main(grpc_port: int, rootcanal_port: int, transport: str) -> None: + if '' in transport: + transport = transport.replace('', str(rootcanal_port)) + device = PandoraDevice({'transport': transport}) + logging.basicConfig(level=logging.DEBUG) + asyncio.run(serve(device, port=grpc_port)) + + +if __name__ == '__main__': + main() # pylint: disable=no-value-for-parameter diff --git a/bumble/pandora/__init__.py b/bumble/pandora/__init__.py new file mode 100644 index 0000000..e02f54a --- /dev/null +++ b/bumble/pandora/__init__.py @@ -0,0 +1,105 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Bumble Pandora server. +This module implement the Pandora Bluetooth test APIs for the Bumble stack. +""" + +__version__ = "0.0.1" + +import grpc +import grpc.aio + +from .config import Config +from .device import PandoraDevice +from .host import HostService +from .security import SecurityService, SecurityStorageService +from pandora.host_grpc_aio import add_HostServicer_to_server +from pandora.security_grpc_aio import ( + add_SecurityServicer_to_server, + add_SecurityStorageServicer_to_server, +) +from typing import Callable, List, Optional + +# public symbols +__all__ = [ + 'register_servicer_hook', + 'serve', + 'Config', + 'PandoraDevice', +] + + +# Add servicers hooks. +_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = [] + + +def register_servicer_hook( + hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None] +) -> None: + _SERVICERS_HOOKS.append(hook) + + +async def serve( + bumble: PandoraDevice, + config: Config = Config(), + grpc_server: Optional[grpc.aio.Server] = None, + port: int = 0, +) -> None: + # initialize a gRPC server if not provided. + server = grpc_server if grpc_server is not None else grpc.aio.server() + port = server.add_insecure_port(f'localhost:{port}') + + try: + while True: + # load server config from dict. + config.load_from_dict(bumble.config.get('server', {})) + + # add Pandora services to the gRPC server. + add_HostServicer_to_server( + HostService(server, bumble.device, config), server + ) + add_SecurityServicer_to_server( + SecurityService(bumble.device, config), server + ) + add_SecurityStorageServicer_to_server( + SecurityStorageService(bumble.device, config), server + ) + + # call hooks if any. + for hook in _SERVICERS_HOOKS: + hook(bumble, config, server) + + # open device. + await bumble.open() + try: + # Pandora require classic devices to be discoverable & connectable. + if bumble.device.classic_enabled: + await bumble.device.set_discoverable(True) + await bumble.device.set_connectable(True) + + # start & serve gRPC server. + await server.start() + await server.wait_for_termination() + finally: + # close device. + await bumble.close() + + # re-initialize the gRPC server. + server = grpc.aio.server() + server.add_insecure_port(f'localhost:{port}') + finally: + # stop server. + await server.stop(None) diff --git a/bumble/pandora/config.py b/bumble/pandora/config.py new file mode 100644 index 0000000..5edba55 --- /dev/null +++ b/bumble/pandora/config.py @@ -0,0 +1,48 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from bumble.pairing import PairingDelegate +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class Config: + io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT + pairing_sc_enable: bool = True + pairing_mitm_enable: bool = True + pairing_bonding_enable: bool = True + smp_local_initiator_key_distribution: PairingDelegate.KeyDistribution = ( + PairingDelegate.DEFAULT_KEY_DISTRIBUTION + ) + smp_local_responder_key_distribution: PairingDelegate.KeyDistribution = ( + PairingDelegate.DEFAULT_KEY_DISTRIBUTION + ) + + def load_from_dict(self, config: Dict[str, Any]) -> None: + io_capability_name: str = config.get( + 'io_capability', 'no_output_no_input' + ).upper() + self.io_capability = getattr(PairingDelegate, io_capability_name) + self.pairing_sc_enable = config.get('pairing_sc_enable', True) + self.pairing_mitm_enable = config.get('pairing_mitm_enable', True) + self.pairing_bonding_enable = config.get('pairing_bonding_enable', True) + self.smp_local_initiator_key_distribution = config.get( + 'smp_local_initiator_key_distribution', + PairingDelegate.DEFAULT_KEY_DISTRIBUTION, + ) + self.smp_local_responder_key_distribution = config.get( + 'smp_local_responder_key_distribution', + PairingDelegate.DEFAULT_KEY_DISTRIBUTION, + ) diff --git a/bumble/pandora/device.py b/bumble/pandora/device.py new file mode 100644 index 0000000..a4403b6 --- /dev/null +++ b/bumble/pandora/device.py @@ -0,0 +1,157 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic & dependency free Bumble (reference) device.""" + +from bumble import transport +from bumble.core import ( + BT_GENERIC_AUDIO_SERVICE, + BT_HANDSFREE_SERVICE, + BT_L2CAP_PROTOCOL_ID, + BT_RFCOMM_PROTOCOL_ID, +) +from bumble.device import Device, DeviceConfiguration +from bumble.host import Host +from bumble.sdp import ( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement, + ServiceAttribute, +) +from typing import Any, Dict, List, Optional + + +class PandoraDevice: + """ + Small wrapper around a Bumble device and it's HCI transport. + Notes: + - The Bumble device is idle by default. + - Repetitive calls to `open`/`close` will result on new Bumble device instances. + """ + + # Bumble device instance & configuration. + device: Device + config: Dict[str, Any] + + # HCI transport name & instance. + _hci_name: str + _hci: Optional[transport.Transport] # type: ignore[name-defined] + + def __init__(self, config: Dict[str, Any]) -> None: + self.config = config + self.device = _make_device(config) + self._hci_name = config.get('transport', '') + self._hci = None + + @property + def idle(self) -> bool: + return self._hci is None + + async def open(self) -> None: + if self._hci is not None: + return + + # open HCI transport & set device host. + self._hci = await transport.open_transport(self._hci_name) + self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call] + + # power-on. + await self.device.power_on() + + async def close(self) -> None: + if self._hci is None: + return + + # flush & re-initialize device. + await self.device.host.flush() + self.device.host = None # type: ignore[assignment] + self.device = _make_device(self.config) + + # close HCI transport. + await self._hci.close() + self._hci = None + + async def reset(self) -> None: + await self.close() + await self.open() + + def info(self) -> Optional[Dict[str, str]]: + return { + 'public_bd_address': str(self.device.public_address), + 'random_address': str(self.device.random_address), + } + + +def _make_device(config: Dict[str, Any]) -> Device: + """Initialize an idle Bumble device instance.""" + + # initialize bumble device. + device_config = DeviceConfiguration() + device_config.load_from_dict(config) + device = Device(config=device_config, host=None) + + # Add fake a2dp service to avoid Android disconnect + device.sdp_service_records = _make_sdp_records(1) + + return device + + +# TODO(b/267540823): remove when Pandora A2dp is supported +def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]: + return { + 0x00010001: [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(0x00010001), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_HANDSFREE_SERVICE), + DataElement.uuid(BT_GENERIC_AUDIO_SERVICE), + ] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]), + DataElement.sequence( + [ + DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), + DataElement.unsigned_integer_8(rfcomm_channel), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_HANDSFREE_SERVICE), + DataElement.unsigned_integer_16(0x0105), + ] + ) + ] + ), + ), + ] + } diff --git a/bumble/pandora/host.py b/bumble/pandora/host.py new file mode 100644 index 0000000..63b295d --- /dev/null +++ b/bumble/pandora/host.py @@ -0,0 +1,856 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import bumble.device +import grpc +import grpc.aio +import logging +import struct + +from . import utils +from .config import Config +from bumble.core import ( + BT_BR_EDR_TRANSPORT, + BT_LE_TRANSPORT, + BT_PERIPHERAL_ROLE, + UUID, + AdvertisingData, + ConnectionError, +) +from bumble.device import ( + DEVICE_DEFAULT_SCAN_INTERVAL, + DEVICE_DEFAULT_SCAN_WINDOW, + Advertisement, + AdvertisingType, + Device, +) +from bumble.gatt import Service +from bumble.hci import ( + HCI_CONNECTION_ALREADY_EXISTS_ERROR, + HCI_PAGE_TIMEOUT_ERROR, + HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, + Address, +) +from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error +from pandora.host_grpc_aio import HostServicer +from pandora.host_pb2 import ( + NOT_CONNECTABLE, + NOT_DISCOVERABLE, + PRIMARY_1M, + PRIMARY_CODED, + SECONDARY_1M, + SECONDARY_2M, + SECONDARY_CODED, + SECONDARY_NONE, + AdvertiseRequest, + AdvertiseResponse, + Connection, + ConnectLERequest, + ConnectLEResponse, + ConnectRequest, + ConnectResponse, + DataTypes, + DisconnectRequest, + InquiryResponse, + PrimaryPhy, + ReadLocalAddressResponse, + ScanningResponse, + ScanRequest, + SecondaryPhy, + SetConnectabilityModeRequest, + SetDiscoverabilityModeRequest, + WaitConnectionRequest, + WaitConnectionResponse, + WaitDisconnectionRequest, +) +from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast + +PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = { + # Default value reported by Bumble for legacy Advertising reports. + # FIXME(uael): `None` might be a better value, but Bumble need to change accordingly. + 0: PRIMARY_1M, + 1: PRIMARY_1M, + 3: PRIMARY_CODED, +} + +SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = { + 0: SECONDARY_NONE, + 1: SECONDARY_1M, + 2: SECONDARY_2M, + 3: SECONDARY_CODED, +} + + +class HostService(HostServicer): + waited_connections: Set[int] + + def __init__( + self, grpc_server: grpc.aio.Server, device: Device, config: Config + ) -> None: + self.log = utils.BumbleServerLoggerAdapter( + logging.getLogger(), {'service_name': 'Host', 'device': device} + ) + self.grpc_server = grpc_server + self.device = device + self.config = config + self.waited_connections = set() + + @utils.rpc + async def FactoryReset( + self, request: empty_pb2.Empty, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + self.log.info('FactoryReset') + + # delete all bonds + if self.device.keystore is not None: + await self.device.keystore.delete_all() + + # trigger gRCP server stop then return + asyncio.create_task(self.grpc_server.stop(None)) + return empty_pb2.Empty() + + @utils.rpc + async def Reset( + self, request: empty_pb2.Empty, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + self.log.info('Reset') + + # clear service. + self.waited_connections.clear() + + # (re) power device on + await self.device.power_on() + return empty_pb2.Empty() + + @utils.rpc + async def ReadLocalAddress( + self, request: empty_pb2.Empty, context: grpc.ServicerContext + ) -> ReadLocalAddressResponse: + self.log.info('ReadLocalAddress') + return ReadLocalAddressResponse( + address=bytes(reversed(bytes(self.device.public_address))) + ) + + @utils.rpc + async def Connect( + self, request: ConnectRequest, context: grpc.ServicerContext + ) -> ConnectResponse: + # Need to reverse bytes order since Bumble Address is using MSB. + address = Address( + bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS + ) + self.log.info(f"Connect to {address}") + + try: + connection = await self.device.connect( + address, transport=BT_BR_EDR_TRANSPORT + ) + except ConnectionError as e: + if e.error_code == HCI_PAGE_TIMEOUT_ERROR: + self.log.warning(f"Peer not found: {e}") + return ConnectResponse(peer_not_found=empty_pb2.Empty()) + if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR: + self.log.warning(f"Connection already exists: {e}") + return ConnectResponse(connection_already_exists=empty_pb2.Empty()) + raise e + + self.log.info(f"Connect to {address} done (handle={connection.handle})") + + cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) + return ConnectResponse(connection=Connection(cookie=cookie)) + + @utils.rpc + async def WaitConnection( + self, request: WaitConnectionRequest, context: grpc.ServicerContext + ) -> WaitConnectionResponse: + if not request.address: + raise ValueError('Request address field must be set') + + # Need to reverse bytes order since Bumble Address is using MSB. + address = Address( + bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS + ) + if address in (Address.NIL, Address.ANY): + raise ValueError('Invalid address') + + self.log.info(f"WaitConnection from {address}...") + + connection = self.device.find_connection_by_bd_addr( + address, transport=BT_BR_EDR_TRANSPORT + ) + if connection and id(connection) in self.waited_connections: + # this connection was already returned: wait for a new one. + connection = None + + if not connection: + connection = await self.device.accept(address) + + # save connection has waited and respond. + self.waited_connections.add(id(connection)) + + self.log.info( + f"WaitConnection from {address} done (handle={connection.handle})" + ) + + cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) + return WaitConnectionResponse(connection=Connection(cookie=cookie)) + + @utils.rpc + async def ConnectLE( + self, request: ConnectLERequest, context: grpc.ServicerContext + ) -> ConnectLEResponse: + address = utils.address_from_request(request, request.WhichOneof("address")) + if address in (Address.NIL, Address.ANY): + raise ValueError('Invalid address') + + self.log.info(f"ConnectLE to {address}...") + + try: + connection = await self.device.connect( + address, + transport=BT_LE_TRANSPORT, + own_address_type=request.own_address_type, + ) + except ConnectionError as e: + if e.error_code == HCI_PAGE_TIMEOUT_ERROR: + self.log.warning(f"Peer not found: {e}") + return ConnectLEResponse(peer_not_found=empty_pb2.Empty()) + if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR: + self.log.warning(f"Connection already exists: {e}") + return ConnectLEResponse(connection_already_exists=empty_pb2.Empty()) + raise e + + self.log.info(f"ConnectLE to {address} done (handle={connection.handle})") + + cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) + return ConnectLEResponse(connection=Connection(cookie=cookie)) + + @utils.rpc + async def Disconnect( + self, request: DisconnectRequest, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + self.log.info(f"Disconnect: {connection_handle}") + + self.log.info("Disconnecting...") + if connection := self.device.lookup_connection(connection_handle): + await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR) + self.log.info("Disconnected") + + return empty_pb2.Empty() + + @utils.rpc + async def WaitDisconnection( + self, request: WaitDisconnectionRequest, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + self.log.info(f"WaitDisconnection: {connection_handle}") + + if connection := self.device.lookup_connection(connection_handle): + disconnection_future: asyncio.Future[ + None + ] = asyncio.get_running_loop().create_future() + + def on_disconnection(_: None) -> None: + disconnection_future.set_result(None) + + connection.on('disconnection', on_disconnection) + try: + await disconnection_future + self.log.info("Disconnected") + finally: + connection.remove_listener('disconnection', on_disconnection) # type: ignore + + return empty_pb2.Empty() + + @utils.rpc + async def Advertise( + self, request: AdvertiseRequest, context: grpc.ServicerContext + ) -> AsyncGenerator[AdvertiseResponse, None]: + if not request.legacy: + raise NotImplementedError( + "TODO: add support for extended advertising in Bumble" + ) + if request.interval: + raise NotImplementedError("TODO: add support for `request.interval`") + if request.interval_range: + raise NotImplementedError("TODO: add support for `request.interval_range`") + if request.primary_phy: + raise NotImplementedError("TODO: add support for `request.primary_phy`") + if request.secondary_phy: + raise NotImplementedError("TODO: add support for `request.secondary_phy`") + + if self.device.is_advertising: + raise NotImplementedError('TODO: add support for advertising sets') + + if data := request.data: + self.device.advertising_data = bytes(self.unpack_data_types(data)) + + if scan_response_data := request.scan_response_data: + self.device.scan_response_data = bytes( + self.unpack_data_types(scan_response_data) + ) + scannable = True + else: + scannable = False + + # Retrieve services data + for service in self.device.gatt_server.attributes: + if isinstance(service, Service) and ( + service_data := service.get_advertising_data() + ): + service_uuid = service.uuid.to_hex_str('-') + if ( + service_uuid in request.data.incomplete_service_class_uuids16 + or service_uuid in request.data.complete_service_class_uuids16 + or service_uuid in request.data.incomplete_service_class_uuids32 + or service_uuid in request.data.complete_service_class_uuids32 + or service_uuid + in request.data.incomplete_service_class_uuids128 + or service_uuid in request.data.complete_service_class_uuids128 + ): + self.device.advertising_data += service_data + if ( + service_uuid + in scan_response_data.incomplete_service_class_uuids16 + or service_uuid + in scan_response_data.complete_service_class_uuids16 + or service_uuid + in scan_response_data.incomplete_service_class_uuids32 + or service_uuid + in scan_response_data.complete_service_class_uuids32 + or service_uuid + in scan_response_data.incomplete_service_class_uuids128 + or service_uuid + in scan_response_data.complete_service_class_uuids128 + ): + self.device.scan_response_data += service_data + + target = None + if request.connectable and scannable: + advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE + elif scannable: + advertising_type = AdvertisingType.UNDIRECTED_SCANNABLE + else: + advertising_type = AdvertisingType.UNDIRECTED + else: + target = None + advertising_type = AdvertisingType.UNDIRECTED + + if request.target: + # Need to reverse bytes order since Bumble Address is using MSB. + target_bytes = bytes(reversed(request.target)) + if request.target_variant() == "public": + target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS) + advertising_type = ( + AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY + ) # FIXME: HIGH_DUTY ? + else: + target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS) + advertising_type = ( + AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY + ) # FIXME: HIGH_DUTY ? + + if request.connectable: + + def on_connection(connection: bumble.device.Connection) -> None: + if ( + connection.transport == BT_LE_TRANSPORT + and connection.role == BT_PERIPHERAL_ROLE + ): + pending_connection.set_result(connection) + + self.device.on('connection', on_connection) + + try: + while True: + if not self.device.is_advertising: + self.log.info('Advertise') + await self.device.start_advertising( + target=target, + advertising_type=advertising_type, + own_address_type=request.own_address_type, + ) + + if not request.connectable: + await asyncio.sleep(1) + continue + + pending_connection: asyncio.Future[ + bumble.device.Connection + ] = asyncio.get_running_loop().create_future() + + self.log.info('Wait for LE connection...') + connection = await pending_connection + + self.log.info( + f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})" + ) + + cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) + yield AdvertiseResponse(connection=Connection(cookie=cookie)) + + # wait a small delay before restarting the advertisement. + await asyncio.sleep(1) + finally: + if request.connectable: + self.device.remove_listener('connection', on_connection) # type: ignore + + try: + self.log.info('Stop advertising') + await self.device.abort_on('flush', self.device.stop_advertising()) + except: + pass + + @utils.rpc + async def Scan( + self, request: ScanRequest, context: grpc.ServicerContext + ) -> AsyncGenerator[ScanningResponse, None]: + # TODO: modify `start_scanning` to accept floats instead of int for ms values + if request.phys: + raise NotImplementedError("TODO: add support for `request.phys`") + + self.log.info('Scan') + + scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue() + handler = self.device.on('advertisement', scan_queue.put_nowait) + await self.device.start_scanning( + legacy=request.legacy, + active=not request.passive, + own_address_type=request.own_address_type, + scan_interval=int(request.interval) + if request.interval + else DEVICE_DEFAULT_SCAN_INTERVAL, + scan_window=int(request.window) + if request.window + else DEVICE_DEFAULT_SCAN_WINDOW, + ) + + try: + # TODO: add support for `direct_address` in Bumble + # TODO: add support for `periodic_advertising_interval` in Bumble + while adv := await scan_queue.get(): + sr = ScanningResponse( + legacy=adv.is_legacy, + connectable=adv.is_connectable, + scannable=adv.is_scannable, + truncated=adv.is_truncated, + sid=adv.sid, + primary_phy=PRIMARY_PHY_MAP[adv.primary_phy], + secondary_phy=SECONDARY_PHY_MAP[adv.secondary_phy], + tx_power=adv.tx_power, + rssi=adv.rssi, + data=self.pack_data_types(adv.data), + ) + + if adv.address.address_type == Address.PUBLIC_DEVICE_ADDRESS: + sr.public = bytes(reversed(bytes(adv.address))) + elif adv.address.address_type == Address.RANDOM_DEVICE_ADDRESS: + sr.random = bytes(reversed(bytes(adv.address))) + elif adv.address.address_type == Address.PUBLIC_IDENTITY_ADDRESS: + sr.public_identity = bytes(reversed(bytes(adv.address))) + else: + sr.random_static_identity = bytes(reversed(bytes(adv.address))) + + yield sr + + finally: + self.device.remove_listener('advertisement', handler) # type: ignore + try: + self.log.info('Stop scanning') + await self.device.abort_on('flush', self.device.stop_scanning()) + except: + pass + + @utils.rpc + async def Inquiry( + self, request: empty_pb2.Empty, context: grpc.ServicerContext + ) -> AsyncGenerator[InquiryResponse, None]: + self.log.info('Inquiry') + + inquiry_queue: asyncio.Queue[ + Optional[Tuple[Address, int, AdvertisingData, int]] + ] = asyncio.Queue() + complete_handler = self.device.on( + 'inquiry_complete', lambda: inquiry_queue.put_nowait(None) + ) + result_handler = self.device.on( # type: ignore + 'inquiry_result', + lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore + (address, class_of_device, eir_data, rssi) # type: ignore + ), + ) + + await self.device.start_discovery(auto_restart=False) + try: + while inquiry_result := await inquiry_queue.get(): + (address, class_of_device, eir_data, rssi) = inquiry_result + # FIXME: if needed, add support for `page_scan_repetition_mode` and `clock_offset` in Bumble + yield InquiryResponse( + address=bytes(reversed(bytes(address))), + class_of_device=class_of_device, + rssi=rssi, + data=self.pack_data_types(eir_data), + ) + + finally: + self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore + self.device.remove_listener('inquiry_result', result_handler) # type: ignore + try: + self.log.info('Stop inquiry') + await self.device.abort_on('flush', self.device.stop_discovery()) + except: + pass + + @utils.rpc + async def SetDiscoverabilityMode( + self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + self.log.info("SetDiscoverabilityMode") + await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE) + return empty_pb2.Empty() + + @utils.rpc + async def SetConnectabilityMode( + self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + self.log.info("SetConnectabilityMode") + await self.device.set_connectable(request.mode != NOT_CONNECTABLE) + return empty_pb2.Empty() + + def unpack_data_types(self, dt: DataTypes) -> AdvertisingData: + ad_structures: List[Tuple[int, bytes]] = [] + + uuids: List[str] + datas: Dict[str, bytes] + + def uuid128_from_str(uuid: str) -> bytes: + """Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX + to byte format.""" + return bytes(reversed(bytes.fromhex(uuid.replace('-', '')))) + + def uuid32_from_str(uuid: str) -> bytes: + """Decode a 32-bit uuid encoded as XXXXXXXX to byte format.""" + return bytes(reversed(bytes.fromhex(uuid))) + + def uuid16_from_str(uuid: str) -> bytes: + """Decode a 16-bit uuid encoded as XXXX to byte format.""" + return bytes(reversed(bytes.fromhex(uuid))) + + if uuids := dt.incomplete_service_class_uuids16: + ad_structures.append( + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid16_from_str(uuid) for uuid in uuids]), + ) + ) + if uuids := dt.complete_service_class_uuids16: + ad_structures.append( + ( + AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid16_from_str(uuid) for uuid in uuids]), + ) + ) + if uuids := dt.incomplete_service_class_uuids32: + ad_structures.append( + ( + AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid32_from_str(uuid) for uuid in uuids]), + ) + ) + if uuids := dt.complete_service_class_uuids32: + ad_structures.append( + ( + AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid32_from_str(uuid) for uuid in uuids]), + ) + ) + if uuids := dt.incomplete_service_class_uuids128: + ad_structures.append( + ( + AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid128_from_str(uuid) for uuid in uuids]), + ) + ) + if uuids := dt.complete_service_class_uuids128: + ad_structures.append( + ( + AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, + b''.join([uuid128_from_str(uuid) for uuid in uuids]), + ) + ) + if dt.HasField('include_shortened_local_name'): + ad_structures.append( + ( + AdvertisingData.SHORTENED_LOCAL_NAME, + bytes(self.device.name[:8], 'utf-8'), + ) + ) + elif dt.shortened_local_name: + ad_structures.append( + ( + AdvertisingData.SHORTENED_LOCAL_NAME, + bytes(dt.shortened_local_name, 'utf-8'), + ) + ) + if dt.HasField('include_complete_local_name'): + ad_structures.append( + (AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.device.name, 'utf-8')) + ) + elif dt.complete_local_name: + ad_structures.append( + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(dt.complete_local_name, 'utf-8'), + ) + ) + if dt.HasField('include_tx_power_level'): + raise ValueError('unsupported data type') + elif dt.tx_power_level: + ad_structures.append( + ( + AdvertisingData.TX_POWER_LEVEL, + bytes(struct.pack(' DataTypes: + dt = DataTypes() + uuids: List[UUID] + s: str + i: int + ij: Tuple[int, int] + uuid_data: Tuple[UUID, bytes] + data: bytes + + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS), + ): + dt.incomplete_service_class_uuids16.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS), + ): + dt.complete_service_class_uuids16.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS), + ): + dt.incomplete_service_class_uuids32.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS), + ): + dt.complete_service_class_uuids32.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS), + ): + dt.incomplete_service_class_uuids128.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS), + ): + dt.complete_service_class_uuids128.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if s := cast(str, ad.get(AdvertisingData.SHORTENED_LOCAL_NAME)): + dt.shortened_local_name = s + if s := cast(str, ad.get(AdvertisingData.COMPLETE_LOCAL_NAME)): + dt.complete_local_name = s + if i := cast(int, ad.get(AdvertisingData.TX_POWER_LEVEL)): + dt.tx_power_level = i + if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)): + dt.class_of_device = i + if ij := cast( + Tuple[int, int], + ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE), + ): + dt.peripheral_connection_interval_min = ij[0] + dt.peripheral_connection_interval_max = ij[1] + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS), + ): + dt.service_solicitation_uuids16.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS), + ): + dt.service_solicitation_uuids32.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuids := cast( + List[UUID], + ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS), + ): + dt.service_solicitation_uuids128.extend( + list(map(lambda x: x.to_hex_str('-'), uuids)) + ) + if uuid_data := cast( + Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID) + ): + dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1] + if uuid_data := cast( + Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID) + ): + dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1] + if uuid_data := cast( + Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID) + ): + dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1] + if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)): + dt.public_target_addresses.extend( + [data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))] + ) + if data := cast(bytes, ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True)): + dt.random_target_addresses.extend( + [data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))] + ) + if i := cast(int, ad.get(AdvertisingData.APPEARANCE)): + dt.appearance = i + if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)): + dt.advertising_interval = i + if s := cast(str, ad.get(AdvertisingData.URI)): + dt.uri = s + if data := cast(bytes, ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True)): + dt.le_supported_features = data + if data := cast( + bytes, ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True) + ): + dt.manufacturer_specific_data = data + + return dt diff --git a/bumble/pandora/py.typed b/bumble/pandora/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py new file mode 100644 index 0000000..fee1b7a --- /dev/null +++ b/bumble/pandora/security.py @@ -0,0 +1,529 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import grpc +import logging + +from . import utils +from .config import Config +from bumble import hci +from bumble.core import ( + BT_BR_EDR_TRANSPORT, + BT_LE_TRANSPORT, + BT_PERIPHERAL_ROLE, + ProtocolError, +) +from bumble.device import Connection as BumbleConnection, Device +from bumble.hci import HCI_Error +from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate +from contextlib import suppress +from google.protobuf import ( + any_pb2, + empty_pb2, + wrappers_pb2, +) # pytype: disable=pyi-error +from google.protobuf.wrappers_pb2 import BoolValue # pytype: disable=pyi-error +from pandora.host_pb2 import Connection +from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer +from pandora.security_pb2 import ( + LE_LEVEL1, + LE_LEVEL2, + LE_LEVEL3, + LE_LEVEL4, + LEVEL0, + LEVEL1, + LEVEL2, + LEVEL3, + LEVEL4, + DeleteBondRequest, + IsBondedRequest, + LESecurityLevel, + PairingEvent, + PairingEventAnswer, + SecureRequest, + SecureResponse, + SecurityLevel, + WaitSecurityRequest, + WaitSecurityResponse, +) +from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union + + +class PairingDelegate(BasePairingDelegate): + def __init__( + self, + connection: BumbleConnection, + service: "SecurityService", + io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT, + local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION, + local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION, + ) -> None: + self.log = utils.BumbleServerLoggerAdapter( + logging.getLogger(), + {'service_name': 'Security', 'device': connection.device}, + ) + self.connection = connection + self.service = service + super().__init__( + io_capability, + local_initiator_key_distribution, + local_responder_key_distribution, + ) + + async def accept(self) -> bool: + return True + + def add_origin(self, ev: PairingEvent) -> PairingEvent: + if not self.connection.is_incomplete: + assert ev.connection + ev.connection.CopyFrom( + Connection( + cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big')) + ) + ) + else: + # In BR/EDR, connection may not be complete, + # use address instead + assert self.connection.transport == BT_BR_EDR_TRANSPORT + ev.address = bytes(reversed(bytes(self.connection.peer_address))) + + return ev + + async def confirm(self, auto: bool = False) -> bool: + self.log.info( + f"Pairing event: `just_works` (io_capability: {self.io_capability})" + ) + + if self.service.event_queue is None or self.service.event_answer is None: + return True + + event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty())) + self.service.event_queue.put_nowait(event) + answer = await anext(self.service.event_answer) # pytype: disable=name-error + assert answer.event == event + assert answer.answer_variant() == 'confirm' and answer.confirm is not None + return answer.confirm + + async def compare_numbers(self, number: int, digits: int = 6) -> bool: + self.log.info( + f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})" + ) + + if self.service.event_queue is None or self.service.event_answer is None: + raise RuntimeError('security: unhandled number comparison request') + + event = self.add_origin(PairingEvent(numeric_comparison=number)) + self.service.event_queue.put_nowait(event) + answer = await anext(self.service.event_answer) # pytype: disable=name-error + assert answer.event == event + assert answer.answer_variant() == 'confirm' and answer.confirm is not None + return answer.confirm + + async def get_number(self) -> Optional[int]: + self.log.info( + f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})" + ) + + if self.service.event_queue is None or self.service.event_answer is None: + raise RuntimeError('security: unhandled number request') + + event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty())) + self.service.event_queue.put_nowait(event) + answer = await anext(self.service.event_answer) # pytype: disable=name-error + assert answer.event == event + if answer.answer_variant() is None: + return None + assert answer.answer_variant() == 'passkey' + return answer.passkey + + async def get_string(self, max_length: int) -> Optional[str]: + self.log.info( + f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})" + ) + + if self.service.event_queue is None or self.service.event_answer is None: + raise RuntimeError('security: unhandled pin_code request') + + event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty())) + self.service.event_queue.put_nowait(event) + answer = await anext(self.service.event_answer) # pytype: disable=name-error + assert answer.event == event + if answer.answer_variant() is None: + return None + assert answer.answer_variant() == 'pin' + + if answer.pin is None: + return None + + pin = answer.pin.decode('utf-8') + if not pin or len(pin) > max_length: + raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes') + + return pin + + async def display_number(self, number: int, digits: int = 6) -> None: + if ( + self.connection.transport == BT_BR_EDR_TRANSPORT + and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY + ): + return + + self.log.info( + f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})" + ) + + if self.service.event_queue is None: + raise RuntimeError('security: unhandled number display request') + + event = self.add_origin(PairingEvent(passkey_entry_notification=number)) + self.service.event_queue.put_nowait(event) + + +BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = { + LEVEL0: lambda connection: True, + LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated, + LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated, + LEVEL3: lambda connection: connection.encryption != 0 + and connection.authenticated + and connection.link_key_type + in ( + hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, + hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, + ), + LEVEL4: lambda connection: connection.encryption + == hci.HCI_Encryption_Change_Event.AES_CCM + and connection.authenticated + and connection.link_key_type + == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, +} + +LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = { + LE_LEVEL1: lambda connection: True, + LE_LEVEL2: lambda connection: connection.encryption != 0, + LE_LEVEL3: lambda connection: connection.encryption != 0 + and connection.authenticated, + LE_LEVEL4: lambda connection: connection.encryption != 0 + and connection.authenticated + and connection.sc, +} + + +class SecurityService(SecurityServicer): + def __init__(self, device: Device, config: Config) -> None: + self.log = utils.BumbleServerLoggerAdapter( + logging.getLogger(), {'service_name': 'Security', 'device': device} + ) + self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None + self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None + self.device = device + self.config = config + + def pairing_config_factory(connection: BumbleConnection) -> PairingConfig: + return PairingConfig( + sc=config.pairing_sc_enable, + mitm=config.pairing_mitm_enable, + bonding=config.pairing_bonding_enable, + delegate=PairingDelegate( + connection, + self, + io_capability=config.io_capability, + local_initiator_key_distribution=config.smp_local_initiator_key_distribution, + local_responder_key_distribution=config.smp_local_responder_key_distribution, + ), + ) + + self.device.pairing_config_factory = pairing_config_factory + + @utils.rpc + async def OnPairing( + self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext + ) -> AsyncGenerator[PairingEvent, None]: + self.log.info('OnPairing') + + if self.event_queue is not None: + raise RuntimeError('already streaming pairing events') + + if len(self.device.connections): + raise RuntimeError( + 'the `OnPairing` method shall be initiated before establishing any connections.' + ) + + self.event_queue = asyncio.Queue() + self.event_answer = request + + try: + while event := await self.event_queue.get(): + yield event + + finally: + self.event_queue = None + self.event_answer = None + + @utils.rpc + async def Secure( + self, request: SecureRequest, context: grpc.ServicerContext + ) -> SecureResponse: + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + self.log.info(f"Secure: {connection_handle}") + + connection = self.device.lookup_connection(connection_handle) + assert connection + + oneof = request.WhichOneof('level') + level = getattr(request, oneof) + assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[ + connection.transport + ] == oneof + + # security level already reached + if self.reached_security_level(connection, level): + return SecureResponse(success=empty_pb2.Empty()) + + # trigger pairing if needed + if self.need_pairing(connection, level): + try: + self.log.info('Pair...') + + if ( + connection.transport == BT_LE_TRANSPORT + and connection.role == BT_PERIPHERAL_ROLE + ): + wait_for_security: asyncio.Future[ + bool + ] = asyncio.get_running_loop().create_future() + connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore + connection.on("pairing_failure", wait_for_security.set_exception) + + connection.request_pairing() + + await wait_for_security + else: + await connection.pair() + + self.log.info('Paired') + except asyncio.CancelledError: + self.log.warning("Connection died during encryption") + return SecureResponse(connection_died=empty_pb2.Empty()) + except (HCI_Error, ProtocolError) as e: + self.log.warning(f"Pairing failure: {e}") + return SecureResponse(pairing_failure=empty_pb2.Empty()) + + # trigger authentication if needed + if self.need_authentication(connection, level): + try: + self.log.info('Authenticate...') + await connection.authenticate() + self.log.info('Authenticated') + except asyncio.CancelledError: + self.log.warning("Connection died during authentication") + return SecureResponse(connection_died=empty_pb2.Empty()) + except (HCI_Error, ProtocolError) as e: + self.log.warning(f"Authentication failure: {e}") + return SecureResponse(authentication_failure=empty_pb2.Empty()) + + # trigger encryption if needed + if self.need_encryption(connection, level): + try: + self.log.info('Encrypt...') + await connection.encrypt() + self.log.info('Encrypted') + except asyncio.CancelledError: + self.log.warning("Connection died during encryption") + return SecureResponse(connection_died=empty_pb2.Empty()) + except (HCI_Error, ProtocolError) as e: + self.log.warning(f"Encryption failure: {e}") + return SecureResponse(encryption_failure=empty_pb2.Empty()) + + # security level has been reached ? + if self.reached_security_level(connection, level): + return SecureResponse(success=empty_pb2.Empty()) + return SecureResponse(not_reached=empty_pb2.Empty()) + + @utils.rpc + async def WaitSecurity( + self, request: WaitSecurityRequest, context: grpc.ServicerContext + ) -> WaitSecurityResponse: + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + self.log.info(f"WaitSecurity: {connection_handle}") + + connection = self.device.lookup_connection(connection_handle) + assert connection + + assert request.level + level = request.level + assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[ + connection.transport + ] == request.level_variant() + + wait_for_security: asyncio.Future[ + str + ] = asyncio.get_running_loop().create_future() + authenticate_task: Optional[asyncio.Future[None]] = None + + async def authenticate() -> None: + assert connection + if (encryption := connection.encryption) != 0: + self.log.debug('Disable encryption...') + try: + await connection.encrypt(enable=False) + except: + pass + self.log.debug('Disable encryption: done') + + self.log.debug('Authenticate...') + await connection.authenticate() + self.log.debug('Authenticate: done') + + if encryption != 0 and connection.encryption != encryption: + self.log.debug('Re-enable encryption...') + await connection.encrypt() + self.log.debug('Re-enable encryption: done') + + def set_failure(name: str) -> Callable[..., None]: + def wrapper(*args: Any) -> None: + self.log.info(f'Wait for security: error `{name}`: {args}') + wait_for_security.set_result(name) + + return wrapper + + def try_set_success(*_: Any) -> None: + assert connection + if self.reached_security_level(connection, level): + self.log.info('Wait for security: done') + wait_for_security.set_result('success') + + def on_encryption_change(*_: Any) -> None: + assert connection + if self.reached_security_level(connection, level): + self.log.info('Wait for security: done') + wait_for_security.set_result('success') + elif ( + connection.transport == BT_BR_EDR_TRANSPORT + and self.need_authentication(connection, level) + ): + nonlocal authenticate_task + if authenticate_task is None: + authenticate_task = asyncio.create_task(authenticate()) + + listeners: Dict[str, Callable[..., None]] = { + 'disconnection': set_failure('connection_died'), + 'pairing_failure': set_failure('pairing_failure'), + 'connection_authentication_failure': set_failure('authentication_failure'), + 'connection_encryption_failure': set_failure('encryption_failure'), + 'pairing': try_set_success, + 'connection_authentication': try_set_success, + 'connection_encryption_change': on_encryption_change, + } + + # register event handlers + for event, listener in listeners.items(): + connection.on(event, listener) + + # security level already reached + if self.reached_security_level(connection, level): + return WaitSecurityResponse(success=empty_pb2.Empty()) + + self.log.info('Wait for security...') + kwargs = {} + kwargs[await wait_for_security] = empty_pb2.Empty() + + # remove event handlers + for event, listener in listeners.items(): + connection.remove_listener(event, listener) # type: ignore + + # wait for `authenticate` to finish if any + if authenticate_task is not None: + self.log.info('Wait for authentication...') + try: + await authenticate_task # type: ignore + except: + pass + self.log.info('Authenticated') + + return WaitSecurityResponse(**kwargs) + + def reached_security_level( + self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel] + ) -> bool: + self.log.debug( + str( + { + 'level': level, + 'encryption': connection.encryption, + 'authenticated': connection.authenticated, + 'sc': connection.sc, + 'link_key_type': connection.link_key_type, + } + ) + ) + + if isinstance(level, LESecurityLevel): + return LE_LEVEL_REACHED[level](connection) + + return BR_LEVEL_REACHED[level](connection) + + def need_pairing(self, connection: BumbleConnection, level: int) -> bool: + if connection.transport == BT_LE_TRANSPORT: + return level >= LE_LEVEL3 and not connection.authenticated + return False + + def need_authentication(self, connection: BumbleConnection, level: int) -> bool: + if connection.transport == BT_LE_TRANSPORT: + return False + if level == LEVEL2 and connection.encryption != 0: + return not connection.authenticated + return level >= LEVEL2 and not connection.authenticated + + def need_encryption(self, connection: BumbleConnection, level: int) -> bool: + # TODO(abel): need to support MITM + if connection.transport == BT_LE_TRANSPORT: + return level == LE_LEVEL2 and not connection.encryption + return level >= LEVEL2 and not connection.encryption + + +class SecurityStorageService(SecurityStorageServicer): + def __init__(self, device: Device, config: Config) -> None: + self.log = utils.BumbleServerLoggerAdapter( + logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device} + ) + self.device = device + self.config = config + + @utils.rpc + async def IsBonded( + self, request: IsBondedRequest, context: grpc.ServicerContext + ) -> wrappers_pb2.BoolValue: + address = utils.address_from_request(request, request.WhichOneof("address")) + self.log.info(f"IsBonded: {address}") + + if self.device.keystore is not None: + is_bonded = await self.device.keystore.get(str(address)) is not None + else: + is_bonded = False + + return BoolValue(value=is_bonded) + + @utils.rpc + async def DeleteBond( + self, request: DeleteBondRequest, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + address = utils.address_from_request(request, request.WhichOneof("address")) + self.log.info(f"DeleteBond: {address}") + + if self.device.keystore is not None: + with suppress(KeyError): + await self.device.keystore.delete(str(address)) + + return empty_pb2.Empty() diff --git a/bumble/pandora/utils.py b/bumble/pandora/utils.py new file mode 100644 index 0000000..c07a5bc --- /dev/null +++ b/bumble/pandora/utils.py @@ -0,0 +1,112 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import grpc +import inspect +import logging + +from bumble.device import Device +from bumble.hci import Address +from google.protobuf.message import Message # pytype: disable=pyi-error +from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple + +ADDRESS_TYPES: Dict[str, int] = { + "public": Address.PUBLIC_DEVICE_ADDRESS, + "random": Address.RANDOM_DEVICE_ADDRESS, + "public_identity": Address.PUBLIC_IDENTITY_ADDRESS, + "random_static_identity": Address.RANDOM_IDENTITY_ADDRESS, +} + + +def address_from_request(request: Message, field: Optional[str]) -> Address: + if field is None: + return Address.ANY + return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field]) + + +class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore + """Formats logs from the PandoraClient.""" + + def process( + self, msg: str, kwargs: MutableMapping[str, Any] + ) -> Tuple[str, MutableMapping[str, Any]]: + assert self.extra + service_name = self.extra['service_name'] + assert isinstance(service_name, str) + device = self.extra['device'] + assert isinstance(device, Device) + addr_bytes = bytes( + reversed(bytes(device.public_address)) + ) # pytype: disable=attribute-error + addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]]) + return (f'[bumble.{service_name}:{addr}] {msg}', kwargs) + + +@contextlib.contextmanager +def exception_to_rpc_error( + context: grpc.ServicerContext, +) -> Generator[None, None, None]: + try: + yield None + except NotImplementedError as e: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore + context.set_details(str(e)) # type: ignore + except ValueError as e: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) # type: ignore + context.set_details(str(e)) # type: ignore + except RuntimeError as e: + context.set_code(grpc.StatusCode.ABORTED) # type: ignore + context.set_details(str(e)) # type: ignore + + +# Decorate an RPC servicer method with a wrapper that transform exceptions to gRPC errors. +def rpc(func: Any) -> Any: + @functools.wraps(func) + async def asyncgen_wrapper( + self: Any, request: Any, context: grpc.ServicerContext + ) -> Any: + with exception_to_rpc_error(context): + async for v in func(self, request, context): + yield v + + @functools.wraps(func) + async def async_wrapper( + self: Any, request: Any, context: grpc.ServicerContext + ) -> Any: + with exception_to_rpc_error(context): + return await func(self, request, context) + + @functools.wraps(func) + def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any: + with exception_to_rpc_error(context): + for v in func(self, request, context): + yield v + + @functools.wraps(func) + def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any: + with exception_to_rpc_error(context): + return func(self, request, context) + + if inspect.isasyncgenfunction(func): + return asyncgen_wrapper + + if inspect.iscoroutinefunction(func): + return async_wrapper + + if inspect.isgenerator(func): + return gen_wrapper + + return wrapper diff --git a/pyproject.toml b/pyproject.toml index f6abc31..8662723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ disable = [ "too-many-statements", ] +[tool.pylint.main] +ignore="pandora" # FIXME: pylint does not support stubs yet: + [tool.pylint.typecheck] signature-mutators="AsyncRunner.run_in_task" diff --git a/setup.cfg b/setup.cfg index 1644b28..45c7264 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ url = https://github.com/google/bumble [options] python_requires = >=3.8 -packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay +packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay, bumble.pandora package_dir = bumble = bumble bumble.apps = apps @@ -33,7 +33,7 @@ install_requires = appdirs >= 1.4 click >= 7.1.2; platform_system!='Emscripten' cryptography == 35; platform_system!='Emscripten' - grpcio >= 1.46; platform_system!='Emscripten' + grpcio == 1.51.1; platform_system!='Emscripten' libusb1 >= 2.0.1; platform_system!='Emscripten' libusb-package == 1.0.26.1; platform_system!='Emscripten' prompt_toolkit >= 3.0.16; platform_system!='Emscripten' @@ -45,6 +45,7 @@ install_requires = websockets >= 8.1; platform_system!='Emscripten' prettytable >= 3.6.0 humanize >= 4.6.0 + bt-test-interfaces >= 0.0.2 [options.entry_points] console_scripts = @@ -60,6 +61,7 @@ console_scripts = bumble-usb-probe = bumble.apps.usb_probe:main bumble-link-relay = bumble.apps.link_relay.link_relay:main bumble-bench = bumble.apps.bench:main + bumble-pandora-server = bumble.apps.pandora_server:main [options.package_data] * = py.typed, *.pyi