# Copyright 2021-2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio import logging import os from typing import ClassVar import click from prompt_toolkit.shortcuts import PromptSession from bumble import data_types, smp from bumble.a2dp import make_audio_sink_service_sdp_records from bumble.att import ( ATT_INSUFFICIENT_AUTHENTICATION_ERROR, ATT_INSUFFICIENT_ENCRYPTION_ERROR, ATT_Error, ) from bumble.colors import color from bumble.core import ( UUID, AdvertisingData, Appearance, DataType, PhysicalTransport, ProtocolError, ) from bumble.device import Connection, Device, Peer from bumble.gatt import ( GATT_DEVICE_NAME_CHARACTERISTIC, GATT_GENERIC_ACCESS_SERVICE, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_SERVICE, Characteristic, Service, ) from bumble.hci import OwnAddressType from bumble.keys import JsonKeyStore from bumble.pairing import OobData, PairingConfig, PairingDelegate from bumble.smp import OobContext, OobLegacyContext from bumble.transport import open_transport from bumble.utils import AsyncRunner # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- POST_PAIRING_DELAY = 1 # ----------------------------------------------------------------------------- class Waiter: instance: ClassVar[Waiter | None] = None def __init__(self, linger=False): self.done = asyncio.get_running_loop().create_future() self.linger = linger def terminate(self): if not self.linger and not self.done.done: self.done.set_result(None) async def wait_until_terminated(self): return await self.done # ----------------------------------------------------------------------------- class Delegate(PairingDelegate): def __init__(self, mode, connection, capability_string, do_prompt): super().__init__( io_capability={ 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, 'none': PairingDelegate.NO_OUTPUT_NO_INPUT, }[capability_string.lower()] ) self.mode = mode self.peer = Peer(connection) self.peer_name = None self.do_prompt = do_prompt def print(self, message): print(color(message, 'yellow')) async def prompt(self, message): # Wait a bit to allow some of the log lines to print before we prompt await asyncio.sleep(1) session = PromptSession(message) response = await session.prompt_async() return response.lower().strip() async def update_peer_name(self): if self.peer_name is not None: # We already asked the peer return # Try to get the peer's name if self.peer: peer_name = await get_peer_name(self.peer, self.mode) self.peer_name = f'{peer_name or ""} [{self.peer.connection.peer_address}]' else: self.peer_name = '[?]' async def accept(self): if self.do_prompt: await self.update_peer_name() # Prompt for acceptance self.print('###-----------------------------------') self.print(f'### Pairing request from {self.peer_name}') self.print('###-----------------------------------') while True: response = await self.prompt('>>> Accept? ') if response == 'yes': return True if response == 'no': return False # Accept silently return True async def compare_numbers(self, number, digits): await self.update_peer_name() # Prompt for a numeric comparison self.print('###-----------------------------------') self.print(f'### Pairing with {self.peer_name}') self.print('###-----------------------------------') while True: response = await self.prompt( f'>>> Does the other device display {number:0{digits}}? ' ) if response == 'yes': return True if response == 'no': return False async def get_number(self): await self.update_peer_name() # Prompt for a PIN while True: try: self.print('###-----------------------------------') self.print(f'### Pairing with {self.peer_name}') self.print('###-----------------------------------') return int(await self.prompt('>>> Enter PIN: ')) except ValueError: pass async def display_number(self, number, digits): await self.update_peer_name() # Display a PIN code self.print('###-----------------------------------') self.print(f'### Pairing with {self.peer_name}') self.print(f'### PIN: {number:0{digits}}') self.print('###-----------------------------------') async def get_string(self, max_length: int): await self.update_peer_name() # Prompt a PIN (for legacy pairing in classic) self.print('###-----------------------------------') self.print(f'### Pairing with {self.peer_name}') self.print('###-----------------------------------') count = 0 while True: response = await self.prompt('>>> Enter PIN (1-6 chars):') if len(response) == 0: count += 1 if count > 3: self.print('too many tries, stopping the pairing') return None self.print('no PIN was entered, try again') continue return response # ----------------------------------------------------------------------------- async def get_peer_name(peer, mode): if peer.connection.transport == PhysicalTransport.BR_EDR: return await peer.request_name() # Try to get the peer name from GATT services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE) if not services: return None values = await peer.read_characteristics_by_uuid( GATT_DEVICE_NAME_CHARACTERISTIC, services[0] ) if values: return values[0].decode('utf-8') return None # ----------------------------------------------------------------------------- AUTHENTICATION_ERROR_RETURNED = [False, False] def read_with_error(connection): if not connection.is_encrypted: raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR) if AUTHENTICATION_ERROR_RETURNED[0]: return bytes([1]) AUTHENTICATION_ERROR_RETURNED[0] = True raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR) # ----------------------------------------------------------------------------- def sdp_records(): service_record_handle = 0x00010001 return { service_record_handle: make_audio_sink_service_sdp_records( service_record_handle ) } # ----------------------------------------------------------------------------- def on_connection(connection, request): print(color(f'<<< Connection: {connection}', 'green')) # Listen for pairing events connection.on(connection.EVENT_PAIRING_START, on_pairing_start) connection.on(connection.EVENT_PAIRING, lambda keys: on_pairing(connection, keys)) connection.on( connection.EVENT_CLASSIC_PAIRING, lambda: on_classic_pairing(connection) ) connection.on( connection.EVENT_PAIRING_FAILURE, lambda reason: on_pairing_failure(connection, reason), ) # Listen for encryption changes connection.on( connection.EVENT_CONNECTION_ENCRYPTION_CHANGE, lambda: on_connection_encryption_change(connection), ) # Request pairing if needed if request: print(color('>>> Requesting pairing', 'green')) connection.request_pairing() # ----------------------------------------------------------------------------- def on_connection_encryption_change(connection): print(color('@@@-----------------------------------', 'blue')) print( color( f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue', ) ) print(color('@@@-----------------------------------', 'blue')) # ----------------------------------------------------------------------------- def on_pairing_start(): print(color('***-----------------------------------', 'magenta')) print(color('*** Pairing starting', 'magenta')) print(color('***-----------------------------------', 'magenta')) # ----------------------------------------------------------------------------- @AsyncRunner.run_in_task() async def on_pairing(connection, keys): print(color('***-----------------------------------', 'cyan')) print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan')) keys.print(prefix=color('*** ', 'cyan')) print(color('***-----------------------------------', 'cyan')) await asyncio.sleep(POST_PAIRING_DELAY) await connection.disconnect() Waiter.instance.terminate() # ----------------------------------------------------------------------------- @AsyncRunner.run_in_task() async def on_classic_pairing(connection): print(color('***-----------------------------------', 'cyan')) print( color( f'*** Paired [Classic]! (peer identity={connection.peer_address})', 'cyan' ) ) print(color('***-----------------------------------', 'cyan')) await asyncio.sleep(POST_PAIRING_DELAY) Waiter.instance.terminate() # ----------------------------------------------------------------------------- @AsyncRunner.run_in_task() async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode): print(color('***-----------------------------------', 'red')) print(color(f'*** Pairing failed: {reason.name}', 'red')) print(color('***-----------------------------------', 'red')) await connection.disconnect() if Waiter.instance: Waiter.instance.terminate() # ----------------------------------------------------------------------------- async def pair( mode: str, sc: bool, mitm: bool, bond: bool, ctkd: bool, advertising_address: str, identity_address: str, linger: bool, io: str, oob: str, prompt: bool, request: bool, print_keys: bool, keystore_file: str, advertise_service_uuids: str, advertise_appearance: str, device_config: str, hci_transport: str, address_or_name: str, ): Waiter.instance = Waiter(linger=linger) print('<<< connecting to HCI...') async with await open_transport(hci_transport) as (hci_source, hci_sink): print('<<< connected') # Create a device to manage the host device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) # Expose a GATT characteristic that can be used to trigger pairing by # responding with an authentication error when read if mode in ('le', 'dual'): device.add_service( Service( GATT_HEART_RATE_SERVICE, [ Characteristic( GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, Characteristic.Properties.READ, Characteristic.READ_REQUIRES_AUTHENTICATION, bytes(1), ) ], ) ) # LE and Classic support if mode in ('classic', 'dual'): device.classic_enabled = True device.classic_smp_enabled = ctkd if mode in ('le', 'dual'): device.le_enabled = True if mode == 'dual': device.le_simultaneous_enabled = True # Setup SDP if mode in ('classic', 'dual'): device.sdp_service_records = sdp_records() # Get things going await device.power_on() # Set a custom keystore if specified on the command line if keystore_file: device.keystore = JsonKeyStore.from_device(device, filename=keystore_file) # Print the existing keys before pairing if print_keys and device.keystore: print(color('@@@-----------------------------------', 'blue')) print(color('@@@ Pairing Keys:', 'blue')) await device.keystore.print(prefix=color('@@@ ', 'blue')) print(color('@@@-----------------------------------', 'blue')) # Create an OOB context if needed if oob: our_oob_context = OobContext() legacy_context: OobLegacyContext | None if oob == '-': shared_data = None legacy_context = OobLegacyContext() else: oob_data = OobData.from_ad( AdvertisingData.from_bytes(bytes.fromhex(oob)) ) shared_data = oob_data.shared_data legacy_context = oob_data.legacy_context if legacy_context is None and not sc: print(color('OOB pairing in legacy mode requires TK', 'red')) return oob_contexts = PairingConfig.OobConfig( our_context=our_oob_context, peer_data=shared_data, legacy_context=legacy_context, ) print(color('@@@-----------------------------------', 'yellow')) print(color('@@@ OOB Data:', 'yellow')) if shared_data is None: oob_data = OobData( address=device.random_address, shared_data=our_oob_context.share(), legacy_context=(None if sc else legacy_context), ) print( color( f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}', 'yellow', ) ) if legacy_context: print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow')) print(color('@@@-----------------------------------', 'yellow')) else: oob_contexts = None # Set up a pairing config factory if identity_address == 'public': identity_address_type = PairingConfig.AddressType.PUBLIC elif identity_address == 'random': identity_address_type = PairingConfig.AddressType.RANDOM else: identity_address_type = None device.pairing_config_factory = lambda connection: PairingConfig( sc=sc, mitm=mitm, bonding=bond, oob=oob_contexts, identity_address_type=identity_address_type, delegate=Delegate(mode, connection, io, prompt), ) # Connect to a peer or wait for a connection device.on('connection', lambda connection: on_connection(connection, request)) if address_or_name is not None: print(color(f'=== Connecting to {address_or_name}...', 'green')) connection = await device.connect( address_or_name, transport=( PhysicalTransport.LE if mode == 'le' else PhysicalTransport.BR_EDR ), ) if not request: try: if mode == 'le': await connection.pair() else: await connection.authenticate() except ProtocolError as error: print(color(f'Pairing failed: {error}', 'red')) else: if mode in ('le', 'dual'): # Advertise so that peers can find us and connect. # Include the heart rate service UUID in the advertisement data # so that devices like iPhones can show this device in their # Bluetooth selector. service_uuids_16 = [] service_uuids_32 = [] service_uuids_128 = [] if advertise_service_uuids: for uuid in advertise_service_uuids: uuid = uuid.replace("-", "") if len(uuid) == 4: service_uuids_16.append(UUID(uuid)) elif len(uuid) == 8: service_uuids_32.append(UUID(uuid)) elif len(uuid) == 32: service_uuids_128.append(UUID(uuid)) else: print(color('Invalid UUID format', 'red')) return else: service_uuids_16.append(GATT_HEART_RATE_SERVICE) flags = AdvertisingData.Flags.LE_LIMITED_DISCOVERABLE_MODE if mode == 'le': flags |= AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED if mode == 'dual': flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE advertising_data_types: list[DataType] = [ data_types.Flags(flags), data_types.CompleteLocalName('Bumble'), ] if service_uuids_16: advertising_data_types.append( data_types.IncompleteListOf16BitServiceUUIDs(service_uuids_16) ) if service_uuids_32: advertising_data_types.append( data_types.IncompleteListOf32BitServiceUUIDs(service_uuids_32) ) if service_uuids_128: advertising_data_types.append( data_types.IncompleteListOf128BitServiceUUIDs(service_uuids_128) ) if advertise_appearance: advertise_appearance = advertise_appearance.upper() try: appearance = data_types.Appearance.from_int( int(advertise_appearance) ) except ValueError: category, subcategory = advertise_appearance.split('/') try: category_enum = Appearance.Category[category] except ValueError: print( color(f'Invalid appearance category {category}', 'red') ) return subcategory_class = Appearance.SUBCATEGORY_CLASSES[ category_enum ] try: subcategory_enum = subcategory_class[subcategory] except ValueError: print(color(f'Invalid subcategory {subcategory}', 'red')) return appearance = data_types.Appearance( category_enum, subcategory_enum ) advertising_data_types.append(appearance) device.advertising_data = bytes(AdvertisingData(advertising_data_types)) await device.start_advertising( auto_restart=True, own_address_type=( OwnAddressType.PUBLIC if advertising_address == 'public' else OwnAddressType.RANDOM ), ) if mode in ('classic', 'dual'): # Become discoverable and connectable await device.set_discoverable(True) await device.set_connectable(True) print(color('Ready for connections on', 'blue'), device.public_address) # Run until the user asks to exit await Waiter.instance.wait_until_terminated() # ----------------------------------------------------------------------------- class LogHandler(logging.Handler): def __init__(self): super().__init__() self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s')) def emit(self, record): message = self.format(record) print(message) # ----------------------------------------------------------------------------- @click.command() @click.option( '--mode', type=click.Choice(['le', 'classic', 'dual']), default='le', show_default=True, ) @click.option( '--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True, ) @click.option( '--mitm', type=bool, default=True, help='Request MITM protection', show_default=True ) @click.option( '--bond', type=bool, default=True, help='Enable bonding', show_default=True ) @click.option( '--ctkd', type=bool, default=True, help='Enable CTKD', show_default=True, ) @click.option( '--advertising-address', type=click.Choice(['random', 'public']), ) @click.option( '--identity-address', type=click.Choice(['random', 'public']), ) @click.option('--linger', default=False, is_flag=True, help='Linger after pairing') @click.option( '--io', type=click.Choice( ['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none'] ), default='display+keyboard', show_default=True, ) @click.option( '--oob', metavar='', help=( 'Use OOB pairing with this data from the peer ' '(use "-" to enable OOB without peer data)' ), ) @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') @click.option( '--request', is_flag=True, help='Request that the connecting peer initiate pairing' ) @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option( '--keystore-file', metavar='FILENAME', help='File in which to store the pairing keys', ) @click.option( '--advertise-service-uuid', metavar="UUID", multiple=True, help="Advertise a GATT service UUID (may be specified more than once)", ) @click.option( '--advertise-appearance', metavar='APPEARANCE', help='Advertise an Appearance ID (int value or string)', ) @click.argument('device-config') @click.argument('hci_transport') @click.argument('address-or-name', required=False) def main( mode: str, sc: bool, mitm: bool, bond: bool, ctkd: bool, advertising_address: str, identity_address: str, linger: bool, io: str, oob: str, prompt: bool, request: bool, print_keys: bool, keystore_file: str, advertise_service_uuid: str, advertise_appearance: str, device_config: str, hci_transport: str, address_or_name: str, ): # Setup logging log_handler = LogHandler() root_logger = logging.getLogger() root_logger.addHandler(log_handler) root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) # Pair asyncio.run( pair( mode, sc, mitm, bond, ctkd, advertising_address, identity_address, linger, io, oob, prompt, request, print_keys, keystore_file, advertise_service_uuid, advertise_appearance, device_config, hci_transport, address_or_name, ) ) # ----------------------------------------------------------------------------- if __name__ == '__main__': main()