From f3b776c343e2c5e9bfc167cc3267e7234a84e8fb Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Thu, 25 Jul 2024 12:46:05 -0700 Subject: [PATCH] wip --- apps/auracast.py | 375 +++++++++++++--- apps/device_info.py | 45 +- apps/lea_unicast/app.py | 29 +- bumble/device.py | 159 +++++-- bumble/gatt.py | 7 +- bumble/gatt_client.py | 6 + bumble/hci.py | 38 +- bumble/profiles/ascs.py | 738 +++++++++++++++++++++++++++++++ bumble/profiles/bap.py | 878 +------------------------------------ bumble/profiles/bass.py | 374 +++++++++++++++- bumble/profiles/pacs.py | 206 +++++++++ bumble/transport/common.py | 4 +- tests/bap_test.py | 12 +- tests/bass_test.py | 145 ++++++ tests/import_test.py | 4 + 15 files changed, 1983 insertions(+), 1037 deletions(-) create mode 100644 bumble/profiles/ascs.py create mode 100644 bumble/profiles/pacs.py create mode 100644 tests/bass_test.py diff --git a/apps/auracast.py b/apps/auracast.py index 89f77a97..6e591eb0 100644 --- a/apps/auracast.py +++ b/apps/auracast.py @@ -17,10 +17,11 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import contextlib import dataclasses import logging import os -from typing import cast, Dict, Optional, Tuple +from typing import cast, Any, AsyncGenerator, Dict, Optional, Tuple import click import pyee @@ -32,6 +33,7 @@ import bumble.device import bumble.gatt import bumble.hci import bumble.profiles.bap +import bumble.profiles.bass import bumble.profiles.pbp import bumble.transport import bumble.utils @@ -46,14 +48,16 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast" -AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5") +AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast' +AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5') +AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0 +AURACAST_DEFAULT_ATT_MTU = 256 # ----------------------------------------------------------------------------- -# Discover Broadcasts +# Scan For Broadcasts # ----------------------------------------------------------------------------- -class BroadcastDiscoverer: +class BroadcastScanner(pyee.EventEmitter): @dataclasses.dataclass class Broadcast(pyee.EventEmitter): name: str @@ -79,22 +83,6 @@ class BroadcastDiscoverer: self.sync.on('periodic_advertisement', self.on_periodic_advertisement) self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement) - self.establishment_timeout_task = asyncio.create_task( - self.wait_for_establishment() - ) - - async def wait_for_establishment(self) -> None: - await asyncio.sleep(5.0) - if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING: - print( - color( - '!!! Periodic advertisement sync not established in time, ' - 'canceling', - 'red', - ) - ) - await self.sync.terminate() - def update(self, advertisement: bumble.device.Advertisement) -> None: self.rssi = advertisement.rssi for service_data in advertisement.data.get_all( @@ -139,6 +127,8 @@ class BroadcastDiscoverer: data, ) + self.emit('update') + def print(self) -> None: print( color('Broadcast:', 'yellow'), @@ -227,13 +217,12 @@ class BroadcastDiscoverer: ) def on_sync_establishment(self) -> None: - self.establishment_timeout_task.cancel() - self.emit('change') + self.emit('sync_establishment') def on_sync_loss(self) -> None: self.basic_audio_announcement = None self.biginfo = None - self.emit('change') + self.emit('sync_loss') def on_periodic_advertisement( self, advertisement: bumble.device.PeriodicAdvertisement @@ -268,37 +257,21 @@ class BroadcastDiscoverer: filter_duplicates: bool, sync_timeout: float, ): + super().__init__() self.device = device self.filter_duplicates = filter_duplicates self.sync_timeout = sync_timeout - self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {} - self.status_message = '' + self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {} device.on('advertisement', self.on_advertisement) - async def run(self) -> None: - self.status_message = color('Scanning...', 'green') + async def start(self) -> None: await self.device.start_scanning( active=False, filter_duplicates=False, ) - def refresh(self) -> None: - # Clear the screen from the top - print('\033[H') - print('\033[0J') - print('\033[H') - - # Print the status message - print(self.status_message) - print("==========================================") - - # Print all broadcasts - for broadcast in self.broadcasts.values(): - broadcast.print() - print('------------------------------------------') - - # Clear the screen to the bottom - print('\033[0J') + async def stop(self) -> None: + await self.device.stop_scanning() def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: if ( @@ -311,7 +284,6 @@ class BroadcastDiscoverer: if broadcast := self.broadcasts.get(advertisement.address): broadcast.update(advertisement) - self.refresh() return bumble.utils.AsyncRunner.spawn( @@ -331,46 +303,281 @@ class BroadcastDiscoverer: name, periodic_advertising_sync, ) - broadcast.on('change', self.refresh) broadcast.update(advertisement) self.broadcasts[advertisement.address] = broadcast periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast)) - self.status_message = color( - f'+Found {len(self.broadcasts)} broadcasts', 'green' - ) - self.refresh() + self.emit('new_broadcast', broadcast) def on_broadcast_loss(self, broadcast: Broadcast) -> None: del self.broadcasts[broadcast.sync.advertiser_address] bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate()) + self.emit('broadcast_loss', broadcast) + + +class PrintingBroadcastScanner: + def __init__( + self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float + ) -> None: + self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout) + self.scanner.on('new_broadcast', self.on_new_broadcast) + self.scanner.on('broadcast_loss', self.on_broadcast_loss) + self.scanner.on('update', self.refresh) + self.status_message = '' + + async def start(self) -> None: + self.status_message = color('Scanning...', 'green') + await self.scanner.start() + + def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None: self.status_message = color( - f'-Found {len(self.broadcasts)} broadcasts', 'green' + f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green' + ) + broadcast.on('change', self.refresh) + broadcast.on('update', self.refresh) + self.refresh() + + def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None: + self.status_message = color( + f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green' ) self.refresh() + def refresh(self) -> None: + # Clear the screen from the top + print('\033[H') + print('\033[0J') + print('\033[H') -async def run_discover_broadcasts( - filter_duplicates: bool, sync_timeout: float, transport: str -) -> None: + # Print the status message + print(self.status_message) + print("==========================================") + + # Print all broadcasts + for broadcast in self.scanner.broadcasts.values(): + broadcast.print() + print('------------------------------------------') + + # Clear the screen to the bottom + print('\033[0J') + + +@contextlib.asynccontextmanager +async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]: async with await bumble.transport.open_transport(transport) as ( hci_source, hci_sink, ): - device = bumble.device.Device.with_hci( - AURACAST_DEFAULT_DEVICE_NAME, - AURACAST_DEFAULT_DEVICE_ADDRESS, + device_config = bumble.device.DeviceConfiguration( + name=AURACAST_DEFAULT_DEVICE_NAME, + address=AURACAST_DEFAULT_DEVICE_ADDRESS, + keystore='JsonKeyStore', + ) + + device = bumble.device.Device.from_config_with_hci( + device_config, hci_source, hci_sink, ) await device.power_on() + yield device + + +async def find_broadcast_by_name( + device: bumble.device.Device, name: Optional[str] +) -> BroadcastScanner.Broadcast: + result = asyncio.get_running_loop().create_future() + + def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None: + if broadcast.basic_audio_announcement and not result.done(): + print(color('Broadcast basic audio announcement received', 'green')) + result.set_result(broadcast) + + def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None: + if name is None or broadcast.name == name: + print(color('Broadcast found:', 'green'), broadcast.name) + broadcast.on('change', lambda: on_broadcast_change(broadcast)) + return + + print(color(f'Skipping broadcast {broadcast.name}')) + + scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT) + scanner.on('new_broadcast', on_new_broadcast) + await scanner.start() + + broadcast = await result + await scanner.stop() + + return broadcast + + +async def run_scan( + filter_duplicates: bool, sync_timeout: float, transport: str +) -> None: + async with create_device(transport) as device: if not device.supports_le_periodic_advertising: print(color('Periodic advertising not supported', 'red')) return - discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout) - await discoverer.run() - await hci_source.terminated + scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout) + await scanner.start() + await asyncio.get_running_loop().create_future() + + +async def run_assist( + broadcast_name: Optional[str], + source_id: Optional[int], + command: str, + transport: str, + address: str, +) -> None: + async with create_device(transport) as device: + if not device.supports_le_periodic_advertising: + print(color('Periodic advertising not supported', 'red')) + return + + # Connect to the server + print(f'=== Connecting to {address}...') + connection = await device.connect(address) + peer = bumble.device.Peer(connection) + print(f'=== Connected to {peer}') + + print("+++ Encrypting connection...") + await peer.connection.encrypt() + print("+++ Connection encrypted") + + # Request a larger MTU + mtu = AURACAST_DEFAULT_ATT_MTU + print(color(f'$$$ Requesting MTU={mtu}', 'yellow')) + await peer.request_mtu(mtu) + + # Get the BASS service + bass = await peer.discover_service_and_create_proxy( + bumble.profiles.bass.BroadcastAudioScanServiceProxy + ) + + # Check that the service was found + if not bass: + print(color('!!! Broadcast Audio Scan Service not found', 'red')) + return + + # Subscribe to and read the broadcast receive state characteristics + for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states): + try: + await broadcast_receive_state.subscribe( + lambda value, i=i: print( + f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}" + ) + ) + except bumble.core.ProtocolError as error: + print( + color( + f'!!! Failed to subscribe to Broadcast Receive State characteristic:', + 'red', + ), + error, + ) + value = await broadcast_receive_state.read_value() + print( + f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}' + ) + + if command == 'monitor-state': + await peer.sustain() + return + + if command == 'add-source': + # Find the requested broadcast + await bass.remote_scan_started() + if broadcast_name: + print(color('Scanning for broadcast:', 'cyan'), broadcast_name) + else: + print(color('Scanning for any broadcast', 'cyan')) + broadcast = await find_broadcast_by_name(device, broadcast_name) + + if ( + broadcast.basic_audio_announcement is None + or not broadcast.basic_audio_announcement.subgroups + ): + print(color('No subgroups found', 'red')) + return + + # Add the source + print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address) + await bass.add_source( + broadcast.sync.advertiser_address, + broadcast.sync.sid, + broadcast.broadcast_audio_announcement.broadcast_id, + bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE, + 0xFFFF, + [ + bumble.profiles.bass.SubgroupInfo( + 0xFFFFFFFF, # bumble.profiles.bass.SubgroupInfo.ANY_BIS, + bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), + ) + ], + ) + + # Initiate a PA Sync Transfer + await broadcast.sync.transfer(peer.connection) + + # Notify the sink that we're done scanning. + await bass.remote_scan_stopped() + + await peer.sustain() + return + + if command == 'modify-source': + if source_id is None: + print(color('!!! modify-source requires --source-id')) + return + + # Modify the source + print( + color('Modifying source:', 'blue'), + source_id, + ) + await bass.modify_source( + source_id, + bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 0xFFFF, + [ + # bumble.profiles.bass.SubgroupInfo( + # 1, # bumble.profiles.bass.SubgroupInfo.ANY_BIS, + # bytes( + # broadcast.basic_audio_announcement.subgroups[0].metadata + # ), + # ) + ], + ) + await peer.sustain() + return + + if command == 'remove-source': + if source_id is None: + print(color('!!! remove-source requires --source-id')) + return + + # Remove the source + print(color('Removing source:', 'blue'), source_id) + await bass.remove_source(source_id) + await peer.sustain() + return + + print(color(f'!!! invalid command {command}')) + + +async def run_pair(transport: str, address: str) -> None: + async with create_device(transport) as device: + + # Connect to the server + print(f'=== Connecting to {address}...') + async with device.connect_as_gatt(address) as peer: + print(f'=== Connected to {peer}') + + print("+++ Initiating pairing...") + await peer.connection.pair() + print("+++ Paired") # ----------------------------------------------------------------------------- @@ -384,7 +591,7 @@ def auracast( ctx.ensure_object(dict) -@auracast.command('discover-broadcasts') +@auracast.command('scan') @click.option( '--filter-duplicates', is_flag=True, default=False, help='Filter duplicates' ) @@ -392,14 +599,50 @@ def auracast( '--sync-timeout', metavar='SYNC_TIMEOUT', type=float, - default=5.0, + default=AURACAST_DEFAULT_SYNC_TIMEOUT, help='Sync timeout (in seconds)', ) @click.argument('transport') @click.pass_context -def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport): - """Discover public broadcasts""" - asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport)) +def scan(ctx, filter_duplicates, sync_timeout, transport): + """Scan for public broadcasts""" + asyncio.run(run_scan(filter_duplicates, sync_timeout, transport)) + + +@auracast.command('assist') +@click.option( + '--broadcast-name', + metavar='BROADCAST_NAME', + help='Broadcast Name to tune to', +) +@click.option( + '--source-id', + metavar='SOURCE_ID', + type=int, + help='Source ID (for remove-source command)', +) +@click.option( + '--command', + type=click.Choice( + ['monitor-state', 'add-source', 'modify-source', 'remove-source'] + ), + required=True, +) +@click.argument('transport') +@click.argument('address') +@click.pass_context +def assist(ctx, broadcast_name, source_id, command, transport, address): + """Scan for broadcasts on behalf of a audio server""" + asyncio.run(run_assist(broadcast_name, source_id, command, transport, address)) + + +@auracast.command('pair') +@click.argument('transport') +@click.argument('address') +@click.pass_context +def pair(ctx, transport, address): + """Pair with an audio server""" + asyncio.run(run_pair(transport, address)) def main(): diff --git a/apps/device_info.py b/apps/device_info.py index 3b885c9d..71489795 100644 --- a/apps/device_info.py +++ b/apps/device_info.py @@ -106,7 +106,7 @@ async def show_battery_level( if battery_service.battery_level: print( - color(' Battery Level: ', 'green'), + color(' Battery Level:', 'green'), await battery_service.battery_level.read_value(), ) @@ -130,32 +130,35 @@ async def show_tmas( # ----------------------------------------------------------------------------- async def show_device_info(peer, done: Optional[asyncio.Future]) -> None: - # Discover all services - print(color('### Discovering Services and Characteristics', 'magenta')) - await peer.discover_services() - for service in peer.services: - await service.discover_characteristics() + try: + # Discover all services + print(color('### Discovering Services and Characteristics', 'magenta')) + await peer.discover_services() + for service in peer.services: + await service.discover_characteristics() - print(color('=== Services ===', 'yellow')) - show_services(peer.services) - print() + print(color('=== Services ===', 'yellow')) + show_services(peer.services) + print() - if gap_service := peer.create_service_proxy(GenericAccessServiceProxy): - await try_show(show_gap_information, gap_service) + if gap_service := peer.create_service_proxy(GenericAccessServiceProxy): + await try_show(show_gap_information, gap_service) - if device_information_service := peer.create_service_proxy( - DeviceInformationServiceProxy - ): - await try_show(show_device_information, device_information_service) + if device_information_service := peer.create_service_proxy( + DeviceInformationServiceProxy + ): + await try_show(show_device_information, device_information_service) - if battery_service := peer.create_service_proxy(BatteryServiceProxy): - await try_show(show_battery_level, battery_service) + if battery_service := peer.create_service_proxy(BatteryServiceProxy): + await try_show(show_battery_level, battery_service) - if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy): - await try_show(show_tmas, tmas) + if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy): + await try_show(show_tmas, tmas) - if done is not None: - done.set_result(None) + if done is not None: + done.set_result(None) + except asyncio.CancelledError: + print(color('!!! Operation canceled', 'red')) # ----------------------------------------------------------------------------- diff --git a/apps/lea_unicast/app.py b/apps/lea_unicast/app.py index ae3b4422..b1f2b3dc 100644 --- a/apps/lea_unicast/app.py +++ b/apps/lea_unicast/app.py @@ -33,7 +33,6 @@ import ctypes import wasmtime import wasmtime.loader import liblc3 # type: ignore -import logging import click import aiohttp.web @@ -43,7 +42,7 @@ from bumble.core import AdvertisingData from bumble.colors import color from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.transport import open_transport -from bumble.profiles import bap +from bumble.profiles import ascs, bap, pacs from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket # ----------------------------------------------------------------------------- @@ -57,8 +56,8 @@ logger = logging.getLogger(__name__) DEFAULT_UI_PORT = 7654 -def _sink_pac_record() -> bap.PacRecord: - return bap.PacRecord( +def _sink_pac_record() -> pacs.PacRecord: + return pacs.PacRecord( coding_format=CodingFormat(CodecID.LC3), codec_specific_capabilities=bap.CodecSpecificCapabilities( supported_sampling_frequencies=( @@ -79,8 +78,8 @@ def _sink_pac_record() -> bap.PacRecord: ) -def _source_pac_record() -> bap.PacRecord: - return bap.PacRecord( +def _source_pac_record() -> pacs.PacRecord: + return pacs.PacRecord( coding_format=CodingFormat(CodecID.LC3), codec_specific_capabilities=bap.CodecSpecificCapabilities( supported_sampling_frequencies=( @@ -447,7 +446,7 @@ class Speaker: ) self.device.add_service( - bap.PublishedAudioCapabilitiesService( + pacs.PublishedAudioCapabilitiesService( supported_source_context=bap.ContextType(0xFFFF), available_source_context=bap.ContextType(0xFFFF), supported_sink_context=bap.ContextType(0xFFFF), # All context types @@ -461,10 +460,10 @@ class Speaker: ) ) - ascs = bap.AudioStreamControlService( + ascs_service = ascs.AudioStreamControlService( self.device, sink_ase_id=[1], source_ase_id=[2] ) - self.device.add_service(ascs) + self.device.add_service(ascs_service) advertising_data = bytes( AdvertisingData( @@ -479,7 +478,7 @@ class Speaker: ), ( AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, - bytes(bap.PublishedAudioCapabilitiesService.UUID), + bytes(pacs.PublishedAudioCapabilitiesService.UUID), ), ] ) @@ -496,11 +495,11 @@ class Speaker: self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) def on_ase_state_change(ase: bap.AseStateMachine) -> None: - if ase.state == bap.AseStateMachine.State.STREAMING: + if ase.state == ascs.AseStateMachine.State.STREAMING: codec_config = ase.codec_specific_configuration assert isinstance(codec_config, bap.CodecSpecificConfiguration) assert ase.cis_link - if ase.role == bap.AudioRole.SOURCE: + if ase.role == ascs.AudioRole.SOURCE: ase.cis_link.abort_on( 'disconnection', lc3_source_task( @@ -516,10 +515,10 @@ class Speaker: ) else: ase.cis_link.sink = functools.partial(on_pdu, ase=ase) - elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED: + elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED: codec_config = ase.codec_specific_configuration assert isinstance(codec_config, bap.CodecSpecificConfiguration) - if ase.role == bap.AudioRole.SOURCE: + if ase.role == ascs.AudioRole.SOURCE: setup_encoders( codec_config.sampling_frequency.hz, codec_config.frame_duration.us, @@ -532,7 +531,7 @@ class Speaker: codec_config.audio_channel_allocation.channel_count, ) - for ase in ascs.ase_state_machines.values(): + for ase in ascs_service.ase_state_machines.values(): ase.on('state_change', functools.partial(on_ase_state_change, ase=ase)) await self.device.power_on() diff --git a/bumble/device.py b/bumble/device.py index 1b5b4840..b60f296a 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -113,6 +113,7 @@ from .hci import ( HCI_LE_Periodic_Advertising_Create_Sync_Command, HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command, HCI_LE_Periodic_Advertising_Report_Event, + HCI_LE_Periodic_Advertising_Sync_Transfer_Command, HCI_LE_Periodic_Advertising_Terminate_Sync_Command, HCI_LE_Enable_Encryption_Command, HCI_LE_Extended_Advertising_Report_Event, @@ -971,20 +972,24 @@ class PeriodicAdvertisingSync(EventEmitter): response = await self.device.send_command( HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(), ) - if response.status == HCI_SUCCESS: + if response.return_parameters == HCI_SUCCESS: if self in self.device.periodic_advertising_syncs: self.device.periodic_advertising_syncs.remove(self) return if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST): self.state = self.State.TERMINATED - await self.device.send_command( - HCI_LE_Periodic_Advertising_Terminate_Sync_Command( - sync_handle=self.sync_handle + if self.sync_handle is not None: + await self.device.send_command( + HCI_LE_Periodic_Advertising_Terminate_Sync_Command( + sync_handle=self.sync_handle + ) ) - ) self.device.periodic_advertising_syncs.remove(self) + async def transfer(self, connection: Connection, service_data: int = 0) -> None: + await connection.transfer_periodic_sync(self.sync_handle, service_data) + def on_establishment( self, status, @@ -1501,11 +1506,9 @@ class Connection(CompositeEventEmitter): try: await asyncio.wait_for(self.device.abort_on('flush', abort), timeout) - except asyncio.TimeoutError: - pass - - self.remove_listener('disconnection', abort.set_result) - self.remove_listener('disconnection_failure', abort.set_exception) + finally: + self.remove_listener('disconnection', abort.set_result) + self.remove_listener('disconnection_failure', abort.set_exception) async def set_data_length(self, tx_octets, tx_time) -> None: return await self.device.set_data_length(self, tx_octets, tx_time) @@ -1536,6 +1539,11 @@ class Connection(CompositeEventEmitter): async def get_phy(self): return await self.device.get_connection_phy(self) + async def transfer_periodic_sync( + self, sync_handle: int, service_data: int = 0 + ) -> None: + await self.device.transfer_periodic_sync(self, sync_handle, service_data) + # [Classic only] async def request_remote_name(self): return await self.device.request_remote_name(self) @@ -2997,18 +3005,47 @@ class Device(CompositeEventEmitter): ] = None, own_address_type: int = OwnAddressType.RANDOM, timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, + always_resolve: bool = False, ) -> Connection: ''' Request a connection to a peer. - When transport is BLE, this method cannot be called if there is already a + + When the transport is BLE, this method cannot be called if there is already a pending connection. - connection_parameters_preferences: (BLE only, ignored for BR/EDR) - * None: use the 1M PHY with default parameters - * map: each entry has a PHY as key and a ConnectionParametersPreferences - object as value + Args: + peer_address: + Address or name of the device to connect to. + If a string is passed: + If the string is an address followed by a `@` suffix, the `always_resolve` + argument is implicitly set to True, so the connection is made to the + address after resolution. + If the string is any other address, the connection is made to that + address (with or without address resolution, depending on the + `always_resolve` argument). + For any other string, a scan for devices using that string as their name + is initiated, and a connection to the first matching device's address + is made. In that case, `always_resolve` is ignored. - own_address_type: (BLE only) + connection_parameters_preferences: + (BLE only, ignored for BR/EDR) + * None: use the 1M PHY with default parameters + * map: each entry has a PHY as key and a ConnectionParametersPreferences + object as value + + own_address_type: + (BLE only, ignored for BR/EDR) + OwnAddressType.RANDOM to use this device's random address, or + OwnAddressType.PUBLIC to use this device's public address. + + timeout: + Maximum time to wait for a connection to be established, in seconds. + Pass None for an unlimited time. + + always_resolve: + (BLE only, ignored for BR/EDR) + If True, always initiate a scan, resolving addresses, and connect to the + address that resolves to `peer_address`. ''' # Check parameters @@ -3027,11 +3064,19 @@ class Device(CompositeEventEmitter): if isinstance(peer_address, str): try: - peer_address = Address.from_string_for_transport( - peer_address, transport - ) + if transport == BT_LE_TRANSPORT and peer_address.endswith('@'): + peer_address = Address.from_string_for_transport( + peer_address[:-1], transport + ) + always_resolve = True + logger.debug('forcing address resolution') + else: + peer_address = Address.from_string_for_transport( + peer_address, transport + ) except (InvalidArgumentError, ValueError): # If the address is not parsable, assume it is a name instead + always_resolve = False logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( peer_address, transport @@ -3046,6 +3091,12 @@ class Device(CompositeEventEmitter): assert isinstance(peer_address, Address) + if transport == BT_LE_TRANSPORT and always_resolve: + logger.debug('resolving address') + peer_address = await self.find_peer_by_identity_address( + peer_address + ) # TODO: timeout + def on_connection(connection): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection event against peer address @@ -3547,15 +3598,25 @@ class Device(CompositeEventEmitter): check_result=True, ) + async def transfer_periodic_sync( + self, connection: Connection, sync_handle: int, service_data: int = 0 + ) -> None: + return await self.send_command( + HCI_LE_Periodic_Advertising_Sync_Transfer_Command( + connection_handle=connection.handle, + service_data=service_data, + sync_handle=sync_handle, + ), check_result=True + ) + async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT): """ - Scan for a peer with a give name and return its address and transport + Scan for a peer with a given name and return its address. """ # Create a future to wait for an address to be found peer_address = asyncio.get_running_loop().create_future() - # Scan/inquire with event handlers to handle scan/inquiry results def on_peer_found(address, ad_data): local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) if local_name is None: @@ -3564,13 +3625,13 @@ class Device(CompositeEventEmitter): if local_name.decode('utf-8') == name: peer_address.set_result(address) - handler = None + listener = None was_scanning = self.scanning was_discovering = self.discovering try: if transport == BT_LE_TRANSPORT: event_name = 'advertisement' - handler = self.on( + listener = self.on( event_name, lambda advertisement: on_peer_found( advertisement.address, advertisement.data @@ -3582,7 +3643,7 @@ class Device(CompositeEventEmitter): elif transport == BT_BR_EDR_TRANSPORT: event_name = 'inquiry_result' - handler = self.on( + listener = self.on( event_name, lambda address, class_of_device, eir_data, rssi: on_peer_found( address, eir_data @@ -3596,14 +3657,60 @@ class Device(CompositeEventEmitter): return await self.abort_on('flush', peer_address) finally: - if handler is not None: - self.remove_listener(event_name, handler) + if listener is not None: + self.remove_listener(event_name, listener) if transport == BT_LE_TRANSPORT and not was_scanning: await self.stop_scanning() elif transport == BT_BR_EDR_TRANSPORT and not was_discovering: await self.stop_discovery() + async def find_peer_by_identity_address(self, identity_address: Address) -> Address: + """ + Scan for a peer with a resolvable address that can be resolved to a given + identity address. + """ + + # Create a future to wait for an address to be found + peer_address = asyncio.get_running_loop().create_future() + + def on_peer_found(address, _): + if address == identity_address: + if not peer_address.done(): + logger.debug(f'*** Matching public address found for {address}') + peer_address.set_result(address) + return + + if address.is_resolvable: + resolved_address = self.address_resolver.resolve(address) + if resolved_address == identity_address: + if not peer_address.done(): + logger.debug(f'*** Matching identity found for {address}') + peer_address.set_result(address) + return + + was_scanning = self.scanning + event_name = 'advertisement' + listener = None + try: + listener = self.on( + event_name, + lambda advertisement: on_peer_found( + advertisement.address, advertisement.data + ), + ) + + if not self.scanning: + await self.start_scanning(filter_duplicates=True) + + return await self.abort_on('flush', peer_address) + finally: + if listener is not None: + self.remove_listener(event_name, listener) + + if not was_scanning: + await self.stop_scanning() + @property def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: return self.smp_manager.pairing_config_factory diff --git a/bumble/gatt.py b/bumble/gatt.py index 896cec01..438c17cf 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -39,7 +39,7 @@ from typing import ( ) from bumble.colors import color -from bumble.core import UUID +from bumble.core import BaseBumbleError, UUID from bumble.att import Attribute, AttributeValue if TYPE_CHECKING: @@ -320,6 +320,11 @@ def show_services(services: Iterable[Service]) -> None: print(color(' ' + str(descriptor), 'green')) +# ----------------------------------------------------------------------------- +class InvalidServiceError(BaseBumbleError): + """The service is not compliant with the spec/profile""" + + # ----------------------------------------------------------------------------- class Service(Attribute): ''' diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 68b829a3..f2b8df65 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -283,6 +283,8 @@ class Client: self.services = [] self.cached_values = {} + connection.on('disconnection', self.on_disconnection) + def send_gatt_pdu(self, pdu: bytes) -> None: self.connection.send_l2cap_pdu(ATT_CID, pdu) @@ -1072,6 +1074,10 @@ class Client: ) ) + def on_disconnection(self, _) -> None: + if self.pending_response and not self.pending_response.done(): + self.pending_response.cancel() + def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' diff --git a/bumble/hci.py b/bumble/hci.py index 7e83f2ff..af39976c 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4529,18 +4529,6 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command): ''' -# ----------------------------------------------------------------------------- -@HCI_Command.command([('sync_handle', 2), ('enable', 1)]) -class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command): - ''' - See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command - ''' - - class Enable(enum.IntFlag): - REPORTING_ENABLED = 1 << 0 - DUPLICATE_FILTERING_ENABLED = 1 << 1 - - # ----------------------------------------------------------------------------- @HCI_Command.command( [ @@ -4576,6 +4564,32 @@ class HCI_LE_Set_Privacy_Mode_Command(HCI_Command): return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode) +# ----------------------------------------------------------------------------- +@HCI_Command.command([('sync_handle', 2), ('enable', 1)]) +class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command + ''' + + class Enable(enum.IntFlag): + REPORTING_ENABLED = 1 << 0 + DUPLICATE_FILTERING_ENABLED = 1 << 1 + + +# ----------------------------------------------------------------------------- +@HCI_Command.command( + fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)], + return_parameters_fields=[ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ], +) +class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command + ''' + + # ----------------------------------------------------------------------------- @HCI_Command.command( fields=[ diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py new file mode 100644 index 00000000..98b4b8ba --- /dev/null +++ b/bumble/profiles/ascs.py @@ -0,0 +1,738 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for + +"""LE Audio - Audio Stream Control Service""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import enum +import logging +import struct +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + +from bumble import colors +from bumble.profiles.bap import CodecSpecificConfiguration +from bumble import device +from bumble import gatt +from bumble import gatt_client +from bumble import hci + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# ASE Operations +# ----------------------------------------------------------------------------- + + +class ASE_Operation: + ''' + See Audio Stream Control Service - 5 ASE Control operations. + ''' + + classes: Dict[int, Type[ASE_Operation]] = {} + op_code: int + name: str + fields: Optional[Sequence[Any]] = None + ase_id: List[int] + + class Opcode(enum.IntEnum): + # fmt: off + CONFIG_CODEC = 0x01 + CONFIG_QOS = 0x02 + ENABLE = 0x03 + RECEIVER_START_READY = 0x04 + DISABLE = 0x05 + RECEIVER_STOP_READY = 0x06 + UPDATE_METADATA = 0x07 + RELEASE = 0x08 + + @staticmethod + def from_bytes(pdu: bytes) -> ASE_Operation: + op_code = pdu[0] + + cls = ASE_Operation.classes.get(op_code) + if cls is None: + instance = ASE_Operation(pdu) + instance.name = ASE_Operation.Opcode(op_code).name + instance.op_code = op_code + return instance + self = cls.__new__(cls) + ASE_Operation.__init__(self, pdu) + if self.fields is not None: + self.init_from_bytes(pdu, 1) + return self + + @staticmethod + def subclass(fields): + def inner(cls: Type[ASE_Operation]): + try: + operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] + cls.name = operation.name + cls.op_code = operation + except: + raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') + cls.fields = fields + + # Register a factory for this class + ASE_Operation.classes[cls.op_code] = cls + + return cls + + return inner + + def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: + if self.fields is not None and kwargs: + hci.HCI_Object.init_from_fields(self, self.fields, kwargs) + if pdu is None: + pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( + kwargs, self.fields + ) + self.pdu = pdu + + def init_from_bytes(self, pdu: bytes, offset: int): + return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) + + def __bytes__(self) -> bytes: + return self.pdu + + def __str__(self) -> str: + result = f'{colors.color(self.name, "yellow")} ' + if fields := getattr(self, 'fields', None): + result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') + else: + if len(self.pdu) > 1: + result += f': {self.pdu.hex()}' + return result + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('target_latency', 1), + ('target_phy', 1), + ('codec_id', hci.CodingFormat.parse_from_bytes), + ('codec_specific_configuration', 'v'), + ], + ] +) +class ASE_Config_Codec(ASE_Operation): + ''' + See Audio Stream Control Service 5.1 - Config Codec Operation + ''' + + target_latency: List[int] + target_phy: List[int] + codec_id: List[hci.CodingFormat] + codec_specific_configuration: List[bytes] + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('cig_id', 1), + ('cis_id', 1), + ('sdu_interval', 3), + ('framing', 1), + ('phy', 1), + ('max_sdu', 2), + ('retransmission_number', 1), + ('max_transport_latency', 2), + ('presentation_delay', 3), + ], + ] +) +class ASE_Config_QOS(ASE_Operation): + ''' + See Audio Stream Control Service 5.2 - Config Qos Operation + ''' + + cig_id: List[int] + cis_id: List[int] + sdu_interval: List[int] + framing: List[int] + phy: List[int] + max_sdu: List[int] + retransmission_number: List[int] + max_transport_latency: List[int] + presentation_delay: List[int] + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Enable(ASE_Operation): + ''' + See Audio Stream Control Service 5.3 - Enable Operation + ''' + + metadata: bytes + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Start_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.4 - Receiver Start Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Disable(ASE_Operation): + ''' + See Audio Stream Control Service 5.5 - Disable Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Stop_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Update_Metadata(ASE_Operation): + ''' + See Audio Stream Control Service 5.7 - Update Metadata Operation + ''' + + metadata: List[bytes] + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Release(ASE_Operation): + ''' + See Audio Stream Control Service 5.8 - Release Operation + ''' + + +class AseResponseCode(enum.IntEnum): + # fmt: off + SUCCESS = 0x00 + UNSUPPORTED_OPCODE = 0x01 + INVALID_LENGTH = 0x02 + INVALID_ASE_ID = 0x03 + INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 + INVALID_ASE_DIRECTION = 0x05 + UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 + UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 + REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 + INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 + UNSUPPORTED_METADATA = 0x0A + REJECTED_METADATA = 0x0B + INVALID_METADATA = 0x0C + INSUFFICIENT_RESOURCES = 0x0D + UNSPECIFIED_ERROR = 0x0E + + +class AseReasonCode(enum.IntEnum): + # fmt: off + NONE = 0x00 + CODEC_ID = 0x01 + CODEC_SPECIFIC_CONFIGURATION = 0x02 + SDU_INTERVAL = 0x03 + FRAMING = 0x04 + PHY = 0x05 + MAXIMUM_SDU_SIZE = 0x06 + RETRANSMISSION_NUMBER = 0x07 + MAX_TRANSPORT_LATENCY = 0x08 + PRESENTATION_DELAY = 0x09 + INVALID_ASE_CIS_MAPPING = 0x0A + + +# ----------------------------------------------------------------------------- +class AudioRole(enum.IntEnum): + SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST + SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER + + +# ----------------------------------------------------------------------------- +class AseStateMachine(gatt.Characteristic): + class State(enum.IntEnum): + # fmt: off + IDLE = 0x00 + CODEC_CONFIGURED = 0x01 + QOS_CONFIGURED = 0x02 + ENABLING = 0x03 + STREAMING = 0x04 + DISABLING = 0x05 + RELEASING = 0x06 + + cis_link: Optional[device.CisLink] = None + + # Additional parameters in CODEC_CONFIGURED State + preferred_framing = 0 # Unframed PDU supported + preferred_phy = 0 + preferred_retransmission_number = 13 + preferred_max_transport_latency = 100 + supported_presentation_delay_min = 0 + supported_presentation_delay_max = 0 + preferred_presentation_delay_min = 0 + preferred_presentation_delay_max = 0 + codec_id = hci.CodingFormat(hci.CodecID.LC3) + codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b'' + + # Additional parameters in QOS_CONFIGURED State + cig_id = 0 + cis_id = 0 + sdu_interval = 0 + framing = 0 + phy = 0 + max_sdu = 0 + retransmission_number = 0 + max_transport_latency = 0 + presentation_delay = 0 + + # Additional parameters in ENABLING, STREAMING, DISABLING State + # TODO: Parse this + metadata = b'' + + def __init__( + self, + role: AudioRole, + ase_id: int, + service: AudioStreamControlService, + ) -> None: + self.service = service + self.ase_id = ase_id + self._state = AseStateMachine.State.IDLE + self.role = role + + uuid = ( + gatt.GATT_SINK_ASE_CHARACTERISTIC + if role == AudioRole.SINK + else gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + super().__init__( + uuid=uuid, + properties=gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.READABLE, + value=gatt.CharacteristicValue(read=self.on_read), + ) + + self.service.device.on('cis_request', self.on_cis_request) + self.service.device.on('cis_establishment', self.on_cis_establishment) + + def on_cis_request( + self, + acl_connection: device.Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ) -> None: + if ( + cig_id == self.cig_id + and cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + acl_connection.abort_on( + 'flush', self.service.device.accept_cis_request(cis_handle) + ) + + def on_cis_establishment(self, cis_link: device.CisLink) -> None: + if ( + cis_link.cig_id == self.cig_id + and cis_link.cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + cis_link.on('disconnection', self.on_cis_disconnection) + + async def post_cis_established(): + await self.service.device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=cis_link.handle, + data_path_direction=self.role, + data_path_id=0x00, # Fixed HCI + codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ) + ) + if self.role == AudioRole.SINK: + self.state = self.State.STREAMING + await self.service.device.notify_subscribers(self, self.value) + + cis_link.acl_connection.abort_on('flush', post_cis_established()) + self.cis_link = cis_link + + def on_cis_disconnection(self, _reason) -> None: + self.cis_link = None + + def on_config_codec( + self, + target_latency: int, + target_phy: int, + codec_id: hci.CodingFormat, + codec_specific_configuration: bytes, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + self.State.IDLE, + self.State.CODEC_CONFIGURED, + self.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.max_transport_latency = target_latency + self.phy = target_phy + self.codec_id = codec_id + if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC: + self.codec_specific_configuration = codec_specific_configuration + else: + self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes( + codec_specific_configuration + ) + + self.state = self.State.CODEC_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_config_qos( + self, + cig_id: int, + cis_id: int, + sdu_interval: int, + framing: int, + phy: int, + max_sdu: int, + retransmission_number: int, + max_transport_latency: int, + presentation_delay: int, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.CODEC_CONFIGURED, + AseStateMachine.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.cig_id = cig_id + self.cis_id = cis_id + self.sdu_interval = sdu_interval + self.framing = framing + self.phy = phy + self.max_sdu = max_sdu + self.retransmission_number = retransmission_number + self.max_transport_latency = max_transport_latency + self.presentation_delay = presentation_delay + + self.state = self.State.QOS_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.QOS_CONFIGURED: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.metadata = metadata + self.state = self.State.ENABLING + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.ENABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.STREAMING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + if self.role == AudioRole.SINK: + self.state = self.State.QOS_CONFIGURED + else: + self.state = self.State.DISABLING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if ( + self.role != AudioRole.SOURCE + or self.state != AseStateMachine.State.DISABLING + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.QOS_CONFIGURED + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_update_metadata( + self, metadata: bytes + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.metadata = metadata + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state == AseStateMachine.State.IDLE: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.RELEASING + + async def remove_cis_async(): + await self.service.device.send_command( + hci.HCI_LE_Remove_ISO_Data_Path_Command( + connection_handle=self.cis_link.handle, + data_path_direction=self.role, + ) + ) + self.state = self.State.IDLE + await self.service.device.notify_subscribers(self, self.value) + + self.service.device.abort_on('flush', remove_cis_async()) + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') + self._state = new_state + self.emit('state_change') + + @property + def value(self): + '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' + + if self.state == self.State.CODEC_CONFIGURED: + codec_specific_configuration_bytes = bytes( + self.codec_specific_configuration + ) + additional_parameters = ( + struct.pack( + ' bytes: + return self.value + + def __str__(self) -> str: + return ( + f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' + f'state={self._state.name})' + ) + + +# ----------------------------------------------------------------------------- +class AudioStreamControlService(gatt.TemplateService): + UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE + + ase_state_machines: Dict[int, AseStateMachine] + ase_control_point: gatt.Characteristic + _active_client: Optional[device.Connection] = None + + def __init__( + self, + device: device.Device, + source_ase_id: Sequence[int] = (), + sink_ase_id: Sequence[int] = (), + ) -> None: + self.device = device + self.ase_state_machines = { + **{ + id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) + for id in sink_ase_id + }, + **{ + id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) + for id in source_ase_id + }, + } # ASE state machines, by ASE ID + + self.ase_control_point = gatt.Characteristic( + uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.WRITE + | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.WRITEABLE, + value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), + ) + + super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) + + def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): + if ase := self.ase_state_machines.get(ase_id): + handler = getattr(ase, 'on_' + opcode.name.lower()) + return (ase_id, *handler(*args)) + else: + return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + + def _on_client_disconnected(self, _reason: int) -> None: + for ase in self.ase_state_machines.values(): + ase.state = AseStateMachine.State.IDLE + self._active_client = None + + def on_write_ase_control_point(self, connection, data): + if not self._active_client and connection: + self._active_client = connection + connection.once('disconnection', self._on_client_disconnected) + + operation = ASE_Operation.from_bytes(data) + responses = [] + logger.debug(f'*** ASCS Write {operation} ***') + + if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: + for ase_id, *args in zip( + operation.ase_id, + operation.target_latency, + operation.target_phy, + operation.codec_id, + operation.codec_specific_configuration, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: + for ase_id, *args in zip( + operation.ase_id, + operation.cig_id, + operation.cis_id, + operation.sdu_interval, + operation.framing, + operation.phy, + operation.max_sdu, + operation.retransmission_number, + operation.max_transport_latency, + operation.presentation_delay, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.ENABLE, + ASE_Operation.Opcode.UPDATE_METADATA, + ): + for ase_id, *args in zip( + operation.ase_id, + operation.metadata, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.RECEIVER_START_READY, + ASE_Operation.Opcode.DISABLE, + ASE_Operation.Opcode.RECEIVER_STOP_READY, + ASE_Operation.Opcode.RELEASE, + ): + for ase_id in operation.ase_id: + responses.append(self.on_operation(operation.op_code, ase_id, [])) + + control_point_notification = bytes( + [operation.op_code, len(responses)] + ) + b''.join(map(bytes, responses)) + self.device.abort_on( + 'flush', + self.device.notify_subscribers( + self.ase_control_point, control_point_notification + ), + ) + + for ase_id, *_ in responses: + if ase := self.ase_state_machines.get(ase_id): + self.device.abort_on( + 'flush', + self.device.notify_subscribers(ase, ase.value), + ) + + +# ----------------------------------------------------------------------------- +class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = AudioStreamControlService + + sink_ase: List[gatt_client.CharacteristicProxy] + source_ase: List[gatt_client.CharacteristicProxy] + ase_control_point: gatt_client.CharacteristicProxy + + def __init__(self, service_proxy: gatt_client.ServiceProxy): + self.service_proxy = service_proxy + + self.sink_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SINK_ASE_CHARACTERISTIC + ) + self.source_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + self.ase_control_point = service_proxy.get_characteristics_by_uuid( + gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC + )[0] diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 117e95e6..8a00eafe 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -24,15 +24,12 @@ import enum import struct import functools import logging -from typing import Optional, List, Union, Type, Dict, Any, Tuple +from typing import List from typing_extensions import Self from bumble import core -from bumble import colors -from bumble import device from bumble import hci from bumble import gatt -from bumble import gatt_client from bumble import utils from bumble.profiles import le_audio @@ -251,231 +248,6 @@ class AnnouncementType(utils.OpenIntEnum): TARGETED = 0x01 -# ----------------------------------------------------------------------------- -# ASE Operations -# ----------------------------------------------------------------------------- - - -class ASE_Operation: - ''' - See Audio Stream Control Service - 5 ASE Control operations. - ''' - - classes: Dict[int, Type[ASE_Operation]] = {} - op_code: int - name: str - fields: Optional[Sequence[Any]] = None - ase_id: List[int] - - class Opcode(enum.IntEnum): - # fmt: off - CONFIG_CODEC = 0x01 - CONFIG_QOS = 0x02 - ENABLE = 0x03 - RECEIVER_START_READY = 0x04 - DISABLE = 0x05 - RECEIVER_STOP_READY = 0x06 - UPDATE_METADATA = 0x07 - RELEASE = 0x08 - - @staticmethod - def from_bytes(pdu: bytes) -> ASE_Operation: - op_code = pdu[0] - - cls = ASE_Operation.classes.get(op_code) - if cls is None: - instance = ASE_Operation(pdu) - instance.name = ASE_Operation.Opcode(op_code).name - instance.op_code = op_code - return instance - self = cls.__new__(cls) - ASE_Operation.__init__(self, pdu) - if self.fields is not None: - self.init_from_bytes(pdu, 1) - return self - - @staticmethod - def subclass(fields): - def inner(cls: Type[ASE_Operation]): - try: - operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] - cls.name = operation.name - cls.op_code = operation - except: - raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') - cls.fields = fields - - # Register a factory for this class - ASE_Operation.classes[cls.op_code] = cls - - return cls - - return inner - - def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: - if self.fields is not None and kwargs: - hci.HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( - kwargs, self.fields - ) - self.pdu = pdu - - def init_from_bytes(self, pdu: bytes, offset: int): - return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - - def __bytes__(self) -> bytes: - return self.pdu - - def __str__(self) -> str: - result = f'{colors.color(self.name, "yellow")} ' - if fields := getattr(self, 'fields', None): - result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') - else: - if len(self.pdu) > 1: - result += f': {self.pdu.hex()}' - return result - - -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('target_latency', 1), - ('target_phy', 1), - ('codec_id', hci.CodingFormat.parse_from_bytes), - ('codec_specific_configuration', 'v'), - ], - ] -) -class ASE_Config_Codec(ASE_Operation): - ''' - See Audio Stream Control Service 5.1 - Config Codec Operation - ''' - - target_latency: List[int] - target_phy: List[int] - codec_id: List[hci.CodingFormat] - codec_specific_configuration: List[bytes] - - -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('cig_id', 1), - ('cis_id', 1), - ('sdu_interval', 3), - ('framing', 1), - ('phy', 1), - ('max_sdu', 2), - ('retransmission_number', 1), - ('max_transport_latency', 2), - ('presentation_delay', 3), - ], - ] -) -class ASE_Config_QOS(ASE_Operation): - ''' - See Audio Stream Control Service 5.2 - Config Qos Operation - ''' - - cig_id: List[int] - cis_id: List[int] - sdu_interval: List[int] - framing: List[int] - phy: List[int] - max_sdu: List[int] - retransmission_number: List[int] - max_transport_latency: List[int] - presentation_delay: List[int] - - -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) -class ASE_Enable(ASE_Operation): - ''' - See Audio Stream Control Service 5.3 - Enable Operation - ''' - - metadata: bytes - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Receiver_Start_Ready(ASE_Operation): - ''' - See Audio Stream Control Service 5.4 - Receiver Start Ready Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Disable(ASE_Operation): - ''' - See Audio Stream Control Service 5.5 - Disable Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Receiver_Stop_Ready(ASE_Operation): - ''' - See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) -class ASE_Update_Metadata(ASE_Operation): - ''' - See Audio Stream Control Service 5.7 - Update Metadata Operation - ''' - - metadata: List[bytes] - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Release(ASE_Operation): - ''' - See Audio Stream Control Service 5.8 - Release Operation - ''' - - -class AseResponseCode(enum.IntEnum): - # fmt: off - SUCCESS = 0x00 - UNSUPPORTED_OPCODE = 0x01 - INVALID_LENGTH = 0x02 - INVALID_ASE_ID = 0x03 - INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 - INVALID_ASE_DIRECTION = 0x05 - UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 - UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 - REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 - INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 - UNSUPPORTED_METADATA = 0x0A - REJECTED_METADATA = 0x0B - INVALID_METADATA = 0x0C - INSUFFICIENT_RESOURCES = 0x0D - UNSPECIFIED_ERROR = 0x0E - - -class AseReasonCode(enum.IntEnum): - # fmt: off - NONE = 0x00 - CODEC_ID = 0x01 - CODEC_SPECIFIC_CONFIGURATION = 0x02 - SDU_INTERVAL = 0x03 - FRAMING = 0x04 - PHY = 0x05 - MAXIMUM_SDU_SIZE = 0x06 - RETRANSMISSION_NUMBER = 0x07 - MAX_TRANSPORT_LATENCY = 0x08 - PRESENTATION_DELAY = 0x09 - INVALID_ASE_CIS_MAPPING = 0x0A - - -class AudioRole(enum.IntEnum): - SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST - SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER - - @dataclasses.dataclass class UnicastServerAdvertisingData: """Advertising Data for ASCS.""" @@ -683,54 +455,6 @@ class CodecSpecificConfiguration: ) -@dataclasses.dataclass -class PacRecord: - '''Published Audio Capabilities Service, Table 3.2/3.4.''' - - coding_format: hci.CodingFormat - codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] - metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata) - - @classmethod - def from_bytes(cls, data: bytes) -> PacRecord: - offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0) - codec_specific_capabilities_size = data[offset] - - offset += 1 - codec_specific_capabilities_bytes = data[ - offset : offset + codec_specific_capabilities_size - ] - offset += codec_specific_capabilities_size - metadata_size = data[offset] - offset += 1 - metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size]) - - codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] - if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: - codec_specific_capabilities = codec_specific_capabilities_bytes - else: - codec_specific_capabilities = CodecSpecificCapabilities.from_bytes( - codec_specific_capabilities_bytes - ) - - return PacRecord( - coding_format=coding_format, - codec_specific_capabilities=codec_specific_capabilities, - metadata=metadata, - ) - - def __bytes__(self) -> bytes: - capabilities_bytes = bytes(self.codec_specific_capabilities) - metadata_bytes = bytes(self.metadata) - return ( - bytes(self.coding_format) - + bytes([len(capabilities_bytes)]) - + capabilities_bytes - + bytes([len(metadata_bytes)]) - + metadata_bytes - ) - - @dataclasses.dataclass class BroadcastAudioAnnouncement: broadcast_id: int @@ -822,603 +546,3 @@ class BasicAudioAnnouncement: ) return cls(presentation_delay, subgroups) - - -# ----------------------------------------------------------------------------- -# Server -# ----------------------------------------------------------------------------- -class PublishedAudioCapabilitiesService(gatt.TemplateService): - UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE - - sink_pac: Optional[gatt.Characteristic] - sink_audio_locations: Optional[gatt.Characteristic] - source_pac: Optional[gatt.Characteristic] - source_audio_locations: Optional[gatt.Characteristic] - available_audio_contexts: gatt.Characteristic - supported_audio_contexts: gatt.Characteristic - - def __init__( - self, - supported_source_context: ContextType, - supported_sink_context: ContextType, - available_source_context: ContextType, - available_sink_context: ContextType, - sink_pac: Sequence[PacRecord] = (), - sink_audio_locations: Optional[AudioLocation] = None, - source_pac: Sequence[PacRecord] = (), - source_audio_locations: Optional[AudioLocation] = None, - ) -> None: - characteristics = [] - - self.supported_audio_contexts = gatt.Characteristic( - uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC, - properties=gatt.Characteristic.Properties.READ, - permissions=gatt.Characteristic.Permissions.READABLE, - value=struct.pack(' None: - self.service = service - self.ase_id = ase_id - self._state = AseStateMachine.State.IDLE - self.role = role - - uuid = ( - gatt.GATT_SINK_ASE_CHARACTERISTIC - if role == AudioRole.SINK - else gatt.GATT_SOURCE_ASE_CHARACTERISTIC - ) - super().__init__( - uuid=uuid, - properties=gatt.Characteristic.Properties.READ - | gatt.Characteristic.Properties.NOTIFY, - permissions=gatt.Characteristic.Permissions.READABLE, - value=gatt.CharacteristicValue(read=self.on_read), - ) - - self.service.device.on('cis_request', self.on_cis_request) - self.service.device.on('cis_establishment', self.on_cis_establishment) - - def on_cis_request( - self, - acl_connection: device.Connection, - cis_handle: int, - cig_id: int, - cis_id: int, - ) -> None: - if ( - cig_id == self.cig_id - and cis_id == self.cis_id - and self.state == self.State.ENABLING - ): - acl_connection.abort_on( - 'flush', self.service.device.accept_cis_request(cis_handle) - ) - - def on_cis_establishment(self, cis_link: device.CisLink) -> None: - if ( - cis_link.cig_id == self.cig_id - and cis_link.cis_id == self.cis_id - and self.state == self.State.ENABLING - ): - cis_link.on('disconnection', self.on_cis_disconnection) - - async def post_cis_established(): - await self.service.device.send_command( - hci.HCI_LE_Setup_ISO_Data_Path_Command( - connection_handle=cis_link.handle, - data_path_direction=self.role, - data_path_id=0x00, # Fixed HCI - codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), - controller_delay=0, - codec_configuration=b'', - ) - ) - if self.role == AudioRole.SINK: - self.state = self.State.STREAMING - await self.service.device.notify_subscribers(self, self.value) - - cis_link.acl_connection.abort_on('flush', post_cis_established()) - self.cis_link = cis_link - - def on_cis_disconnection(self, _reason) -> None: - self.cis_link = None - - def on_config_codec( - self, - target_latency: int, - target_phy: int, - codec_id: hci.CodingFormat, - codec_specific_configuration: bytes, - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - self.State.IDLE, - self.State.CODEC_CONFIGURED, - self.State.QOS_CONFIGURED, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.max_transport_latency = target_latency - self.phy = target_phy - self.codec_id = codec_id - if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC: - self.codec_specific_configuration = codec_specific_configuration - else: - self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes( - codec_specific_configuration - ) - - self.state = self.State.CODEC_CONFIGURED - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_config_qos( - self, - cig_id: int, - cis_id: int, - sdu_interval: int, - framing: int, - phy: int, - max_sdu: int, - retransmission_number: int, - max_transport_latency: int, - presentation_delay: int, - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.CODEC_CONFIGURED, - AseStateMachine.State.QOS_CONFIGURED, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.cig_id = cig_id - self.cis_id = cis_id - self.sdu_interval = sdu_interval - self.framing = framing - self.phy = phy - self.max_sdu = max_sdu - self.retransmission_number = retransmission_number - self.max_transport_latency = max_transport_latency - self.presentation_delay = presentation_delay - - self.state = self.State.QOS_CONFIGURED - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.QOS_CONFIGURED: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.metadata = le_audio.Metadata.from_bytes(metadata) - self.state = self.State.ENABLING - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.ENABLING: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.STREAMING - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.ENABLING, - AseStateMachine.State.STREAMING, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - if self.role == AudioRole.SINK: - self.state = self.State.QOS_CONFIGURED - else: - self.state = self.State.DISABLING - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if ( - self.role != AudioRole.SOURCE - or self.state != AseStateMachine.State.DISABLING - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.QOS_CONFIGURED - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_update_metadata( - self, metadata: bytes - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.ENABLING, - AseStateMachine.State.STREAMING, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.metadata = le_audio.Metadata.from_bytes(metadata) - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state == AseStateMachine.State.IDLE: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.RELEASING - - async def remove_cis_async(): - await self.service.device.send_command( - hci.HCI_LE_Remove_ISO_Data_Path_Command( - connection_handle=self.cis_link.handle, - data_path_direction=self.role, - ) - ) - self.state = self.State.IDLE - await self.service.device.notify_subscribers(self, self.value) - - self.service.device.abort_on('flush', remove_cis_async()) - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - @property - def state(self) -> State: - return self._state - - @state.setter - def state(self, new_state: State) -> None: - logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') - self._state = new_state - self.emit('state_change') - - @property - def value(self): - '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' - - if self.state == self.State.CODEC_CONFIGURED: - codec_specific_configuration_bytes = bytes( - self.codec_specific_configuration - ) - additional_parameters = ( - struct.pack( - ' bytes: - return self.value - - def __str__(self) -> str: - return ( - f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' - f'state={self._state.name})' - ) - - -class AudioStreamControlService(gatt.TemplateService): - UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE - - ase_state_machines: Dict[int, AseStateMachine] - ase_control_point: gatt.Characteristic - _active_client: Optional[device.Connection] = None - - def __init__( - self, - device: device.Device, - source_ase_id: Sequence[int] = [], - sink_ase_id: Sequence[int] = [], - ) -> None: - self.device = device - self.ase_state_machines = { - **{ - id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) - for id in sink_ase_id - }, - **{ - id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) - for id in source_ase_id - }, - } # ASE state machines, by ASE ID - - self.ase_control_point = gatt.Characteristic( - uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, - properties=gatt.Characteristic.Properties.WRITE - | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE - | gatt.Characteristic.Properties.NOTIFY, - permissions=gatt.Characteristic.Permissions.WRITEABLE, - value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), - ) - - super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) - - def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): - if ase := self.ase_state_machines.get(ase_id): - handler = getattr(ase, 'on_' + opcode.name.lower()) - return (ase_id, *handler(*args)) - else: - return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) - - def _on_client_disconnected(self, _reason: int) -> None: - for ase in self.ase_state_machines.values(): - ase.state = AseStateMachine.State.IDLE - self._active_client = None - - def on_write_ase_control_point(self, connection, data): - if not self._active_client and connection: - self._active_client = connection - connection.once('disconnection', self._on_client_disconnected) - - operation = ASE_Operation.from_bytes(data) - responses = [] - logger.debug(f'*** ASCS Write {operation} ***') - - if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: - for ase_id, *args in zip( - operation.ase_id, - operation.target_latency, - operation.target_phy, - operation.codec_id, - operation.codec_specific_configuration, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: - for ase_id, *args in zip( - operation.ase_id, - operation.cig_id, - operation.cis_id, - operation.sdu_interval, - operation.framing, - operation.phy, - operation.max_sdu, - operation.retransmission_number, - operation.max_transport_latency, - operation.presentation_delay, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.ENABLE, - ASE_Operation.Opcode.UPDATE_METADATA, - ): - for ase_id, *args in zip( - operation.ase_id, - operation.metadata, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.RECEIVER_START_READY, - ASE_Operation.Opcode.DISABLE, - ASE_Operation.Opcode.RECEIVER_STOP_READY, - ASE_Operation.Opcode.RELEASE, - ): - for ase_id in operation.ase_id: - responses.append(self.on_operation(operation.op_code, ase_id, [])) - - control_point_notification = bytes( - [operation.op_code, len(responses)] - ) + b''.join(map(bytes, responses)) - self.device.abort_on( - 'flush', - self.device.notify_subscribers( - self.ase_control_point, control_point_notification - ), - ) - - for ase_id, *_ in responses: - if ase := self.ase_state_machines.get(ase_id): - self.device.abort_on( - 'flush', - self.device.notify_subscribers(ase, ase.value), - ) - - -# ----------------------------------------------------------------------------- -# Client -# ----------------------------------------------------------------------------- -class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy): - SERVICE_CLASS = PublishedAudioCapabilitiesService - - sink_pac: Optional[gatt_client.CharacteristicProxy] = None - sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None - source_pac: Optional[gatt_client.CharacteristicProxy] = None - source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None - available_audio_contexts: gatt_client.CharacteristicProxy - supported_audio_contexts: gatt_client.CharacteristicProxy - - def __init__(self, service_proxy: gatt_client.ServiceProxy): - self.service_proxy = service_proxy - - self.available_audio_contexts = service_proxy.get_characteristics_by_uuid( - gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC - )[0] - self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC - )[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_PAC_CHARACTERISTIC - ): - self.sink_pac = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_PAC_CHARACTERISTIC - ): - self.source_pac = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC - ): - self.sink_audio_locations = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC - ): - self.source_audio_locations = characteristics[0] - - -class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): - SERVICE_CLASS = AudioStreamControlService - - sink_ase: List[gatt_client.CharacteristicProxy] - source_ase: List[gatt_client.CharacteristicProxy] - ase_control_point: gatt_client.CharacteristicProxy - - def __init__(self, service_proxy: gatt_client.ServiceProxy): - self.service_proxy = service_proxy - - self.sink_ase = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_ASE_CHARACTERISTIC - ) - self.source_ase = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_ASE_CHARACTERISTIC - ) - self.ase_control_point = service_proxy.get_characteristics_by_uuid( - gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC - )[0] diff --git a/bumble/profiles/bass.py b/bumble/profiles/bass.py index a12f44de..57531dbd 100644 --- a/bumble/profiles/bass.py +++ b/bumble/profiles/bass.py @@ -16,12 +16,17 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations +import dataclasses import logging -from typing import Optional +import struct +from typing import ClassVar, List, Optional, Sequence +from bumble import core from bumble import device from bumble import gatt from bumble import gatt_client +from bumble import hci from bumble import utils # ----------------------------------------------------------------------------- @@ -38,6 +43,284 @@ class ApplicationError(utils.OpenIntEnum): INVALID_SOURCE_ID = 0x81 +# ----------------------------------------------------------------------------- +def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes: + return bytes([len(subgroups)]) + b"".join( + struct.pack(" List[SubgroupInfo]: + num_subgroups = data[0] + offset = 1 + subgroups = [] + for _ in range(num_subgroups): + bis_sync = struct.unpack(" ControlPointOperation: + op_code = data[0] + + if op_code == cls.OpCode.REMOTE_SCAN_STOPPED: + return RemoteScanStoppedOperation() + + if op_code == cls.OpCode.REMOTE_SCAN_STARTED: + return RemoteScanStartedOperation() + + if op_code == cls.OpCode.ADD_SOURCE: + return AddSourceOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.MODIFY_SOURCE: + return ModifySourceOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.SET_BROADCAST_CODE: + return SetBroadcastCodeOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.REMOVE_SOURCE: + return RemoveSourceOperation.from_parameters(data[1:]) + + raise core.InvalidArgumentError("invalid op code") + + def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None: + self.op_code = op_code + self.parameters = parameters + + def __bytes__(self) -> bytes: + return bytes([self.op_code]) + self.parameters + + +class RemoteScanStoppedOperation(ControlPointOperation): + def __init__(self) -> None: + super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED) + + +class RemoteScanStartedOperation(ControlPointOperation): + def __init__(self) -> None: + super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED) + + +class AddSourceOperation(ControlPointOperation): + @classmethod + def from_parameters(cls, parameters: bytes) -> AddSourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE + instance.parameters = parameters + instance.advertiser_address = hci.Address.parse_address_preceded_by_type( + parameters, 1 + )[1] + instance.advertising_sid = parameters[7] + instance.broadcast_id = int.from_bytes(parameters[8:11], "little") + instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11]) + instance.pa_interval = struct.unpack(" None: + super().__init__( + ControlPointOperation.OpCode.ADD_SOURCE, + struct.pack( + " ModifySourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE + instance.parameters = parameters + instance.source_id = parameters[0] + instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1]) + instance.pa_interval = struct.unpack(" None: + super().__init__( + ControlPointOperation.OpCode.MODIFY_SOURCE, + struct.pack(" SetBroadcastCodeOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE + instance.parameters = parameters + instance.source_id = parameters[0] + instance.broadcast_code = parameters[1:17] + return instance + + def __init__( + self, + source_id: int, + broadcast_code: bytes, + ) -> None: + super().__init__( + ControlPointOperation.OpCode.SET_BROADCAST_CODE, + bytes([source_id]) + broadcast_code, + ) + self.source_id = source_id + self.broadcast_code = broadcast_code + + if len(self.broadcast_code) != 16: + raise core.InvalidArgumentError("broadcast_code must be 16 bytes") + + +class RemoveSourceOperation(ControlPointOperation): + @classmethod + def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE + instance.parameters = parameters + instance.source_id = parameters[0] + return instance + + def __init__(self, source_id: int) -> None: + super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id])) + self.source_id = source_id + + +@dataclasses.dataclass +class BroadcastReceiveState: + class PeriodicAdvertisingSyncState(utils.OpenIntEnum): + NOT_SYNCHRONIZED_TO_PA = 0x00 + SYNCINFO_REQUEST = 0x01 + SYNCHRONIZED_TO_PA = 0x02 + FAILED_TO_SYNCHRONIZE_TO_PA = 0x03 + NO_PAST = 0x04 + + class BigEncryption(utils.OpenIntEnum): + NOT_ENCRYPTED = 0x00 + BROADCAST_CODE_REQUIRED = 0x01 + DECRYPTING = 0x02 + BAD_CODE = 0x03 + + source_id: int + source_address: hci.Address + source_adv_sid: int + broadcast_id: int + pa_sync_state: PeriodicAdvertisingSyncState + big_encryption: BigEncryption + bad_code: bytes + subgroups: List[SubgroupInfo] + + @classmethod + def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]: + if not data: + return None + + source_id = data[0] + _, source_address = hci.Address.parse_address_preceded_by_type(data, 2) + source_adv_sid = data[8] + broadcast_id = int.from_bytes(data[9:12], "little") + pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12]) + big_encryption = cls.BigEncryption(data[13]) + if big_encryption == cls.BigEncryption.BAD_CODE: + bad_code = data[14:30] + subgroups = decode_subgroups(data[30:]) + else: + bad_code = b"" + subgroups = decode_subgroups(data[14:]) + + return cls( + source_id, + source_address, + source_adv_sid, + broadcast_id, + pa_sync_state, + big_encryption, + bad_code, + subgroups, + ) + + def __bytes__(self) -> bytes: + return ( + struct.pack( + " None: + await self.broadcast_audio_scan_control_point.write_value( + bytes(operation), with_response=True + ) + + async def remote_scan_started(self) -> None: + await self.send_control_point_operation(RemoteScanStartedOperation()) + + async def remote_scan_stopped(self) -> None: + await self.send_control_point_operation(RemoteScanStoppedOperation()) + + async def add_source( + self, + advertiser_address: hci.Address, + advertising_sid: int, + broadcast_id: int, + pa_sync: PeriodicAdvertisingSyncParams, + pa_interval: int, + subgroups: Sequence[SubgroupInfo], + ) -> None: + await self.send_control_point_operation( + AddSourceOperation( + advertiser_address, + advertising_sid, + broadcast_id, + pa_sync, + pa_interval, + subgroups, + ) + ) + + async def modify_source( + self, + source_id: int, + pa_sync: PeriodicAdvertisingSyncParams, + pa_interval: int, + subgroups: Sequence[SubgroupInfo], + ) -> None: + await self.send_control_point_operation( + ModifySourceOperation( + source_id, + pa_sync, + pa_interval, + subgroups, + ) + ) + + async def remove_source(self, source_id: int) -> None: + await self.send_control_point_operation(RemoveSourceOperation(source_id)) diff --git a/bumble/profiles/pacs.py b/bumble/profiles/pacs.py new file mode 100644 index 00000000..b36477dc --- /dev/null +++ b/bumble/profiles/pacs.py @@ -0,0 +1,206 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for + +"""LE Audio - Published Audio Capabilities Service""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import dataclasses +import logging +import struct +from typing import Optional, Sequence, Union + +from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType +from bumble import gatt +from bumble import gatt_client +from bumble import hci + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class PacRecord: + coding_format: hci.CodingFormat + codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] + # TODO: Parse Metadata + metadata: bytes = b'' + + @classmethod + def from_bytes(cls, data: bytes) -> PacRecord: + offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0) + codec_specific_capabilities_size = data[offset] + + offset += 1 + codec_specific_capabilities_bytes = data[ + offset : offset + codec_specific_capabilities_size + ] + offset += codec_specific_capabilities_size + metadata_size = data[offset] + metadata = data[offset : offset + metadata_size] + + codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] + if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: + codec_specific_capabilities = codec_specific_capabilities_bytes + else: + codec_specific_capabilities = CodecSpecificCapabilities.from_bytes( + codec_specific_capabilities_bytes + ) + + return PacRecord( + coding_format=coding_format, + codec_specific_capabilities=codec_specific_capabilities, + metadata=metadata, + ) + + def __bytes__(self) -> bytes: + capabilities_bytes = bytes(self.codec_specific_capabilities) + return ( + bytes(self.coding_format) + + bytes([len(capabilities_bytes)]) + + capabilities_bytes + + bytes([len(self.metadata)]) + + self.metadata + ) + + +# ----------------------------------------------------------------------------- +# Server +# ----------------------------------------------------------------------------- +class PublishedAudioCapabilitiesService(gatt.TemplateService): + UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE + + sink_pac: Optional[gatt.Characteristic] + sink_audio_locations: Optional[gatt.Characteristic] + source_pac: Optional[gatt.Characteristic] + source_audio_locations: Optional[gatt.Characteristic] + available_audio_contexts: gatt.Characteristic + supported_audio_contexts: gatt.Characteristic + + def __init__( + self, + supported_source_context: ContextType, + supported_sink_context: ContextType, + available_source_context: ContextType, + available_sink_context: ContextType, + sink_pac: Sequence[PacRecord] = (), + sink_audio_locations: Optional[AudioLocation] = None, + source_pac: Sequence[PacRecord] = (), + source_audio_locations: Optional[AudioLocation] = None, + ) -> None: + characteristics = [] + + self.supported_audio_contexts = gatt.Characteristic( + uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.READ, + permissions=gatt.Characteristic.Permissions.READABLE, + value=struct.pack(' None: - self.terminated.set_result(None) + if not self.terminated.done(): + self.terminated.set_result(None) + if self.sink: if hasattr(self.sink, 'on_transport_lost'): self.sink.on_transport_lost() diff --git a/tests/bap_test.py b/tests/bap_test.py index e276790c..0b57fcd2 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -23,8 +23,9 @@ import logging from bumble import device from bumble.hci import CodecID, CodingFormat -from bumble.profiles.bap import ( - AudioLocation, +from bumble.profiles.ascs import ( + AudioStreamControlService, + AudioStreamControlServiceProxy, AseStateMachine, ASE_Operation, ASE_Config_Codec, @@ -35,6 +36,9 @@ from bumble.profiles.bap import ( ASE_Receiver_Stop_Ready, ASE_Release, ASE_Update_Metadata, +) +from bumble.profiles.bap import ( + AudioLocation, SupportedFrameDuration, SupportedSamplingFrequency, SamplingFrequency, @@ -42,9 +46,9 @@ from bumble.profiles.bap import ( CodecSpecificCapabilities, CodecSpecificConfiguration, ContextType, +) +from bumble.profiles.pacs import ( PacRecord, - AudioStreamControlService, - AudioStreamControlServiceProxy, PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesServiceProxy, ) diff --git a/tests/bass_test.py b/tests/bass_test.py new file mode 100644 index 00000000..4cf7ec4e --- /dev/null +++ b/tests/bass_test.py @@ -0,0 +1,145 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import os +import logging + +from bumble import hci +from bumble.profiles import bass + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def basic_operation_check(operation: bass.ControlPointOperation) -> None: + serialized = bytes(operation) + parsed = bass.ControlPointOperation.from_bytes(serialized) + assert bytes(parsed) == serialized + + +# ----------------------------------------------------------------------------- +def test_operations() -> None: + op1 = bass.RemoteScanStoppedOperation() + basic_operation_check(op1) + + op2 = bass.RemoteScanStartedOperation() + basic_operation_check(op2) + + op3 = bass.AddSourceOperation( + hci.Address("AA:BB:CC:DD:EE:FF"), + 34, + 123456, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 456, + (), + ) + basic_operation_check(op3) + + op4 = bass.AddSourceOperation( + hci.Address("AA:BB:CC:DD:EE:FF"), + 34, + 123456, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 456, + ( + bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')), + bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')), + ), + ) + basic_operation_check(op4) + + op5 = bass.ModifySourceOperation( + 12, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 567, + (), + ) + basic_operation_check(op5) + + op6 = bass.ModifySourceOperation( + 12, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 567, + ( + bass.SubgroupInfo(6677, bytes.fromhex('112233')), + bass.SubgroupInfo(8899, bytes.fromhex('4567')), + ), + ) + basic_operation_check(op6) + + op7 = bass.SetBroadcastCodeOperation( + 7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf') + ) + basic_operation_check(op7) + + op8 = bass.RemoveSourceOperation(7) + basic_operation_check(op8) + + +# ----------------------------------------------------------------------------- +def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None: + serialized = bytes(brs) + parsed = bass.BroadcastReceiveState.from_bytes(serialized) + assert bytes(parsed) == serialized + + +def test_broadcast_receive_state() -> None: + subgroups = [ + bass.SubgroupInfo(6677, bytes.fromhex('112233')), + bass.SubgroupInfo(8899, bytes.fromhex('4567')), + ] + + brs1 = bass.BroadcastReceiveState( + 12, + hci.Address("AA:BB:CC:DD:EE:FF"), + 123, + 123456, + bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA, + bass.BroadcastReceiveState.BigEncryption.DECRYPTING, + b'', + subgroups, + ) + basic_broadcast_receive_state_check(brs1) + + brs2 = bass.BroadcastReceiveState( + 12, + hci.Address("AA:BB:CC:DD:EE:FF"), + 123, + 123456, + bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA, + bass.BroadcastReceiveState.BigEncryption.BAD_CODE, + bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'), + subgroups, + ) + basic_broadcast_receive_state_check(brs2) + + +# ----------------------------------------------------------------------------- +async def run(): + test_operations() + test_broadcast_receive_state() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run()) diff --git a/tests/import_test.py b/tests/import_test.py index 502f5936..95425112 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -41,6 +41,7 @@ def test_import(): ) from bumble.profiles import ( + ascs, bap, bass, battery_service, @@ -50,6 +51,7 @@ def test_import(): gap, heart_rate_service, le_audio, + pacs, pbp, vcp, ) @@ -73,6 +75,7 @@ def test_import(): assert transport assert utils + assert ascs assert bap assert bass assert battery_service @@ -82,6 +85,7 @@ def test_import(): assert gap assert heart_rate_service assert le_audio + assert pacs assert pbp assert vcp