From 7324d322fe3620be537f140707be3a55c3199fce Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Tue, 12 Nov 2024 17:38:17 +0800 Subject: [PATCH 1/2] BIG --- .vscode/settings.json | 7 + apps/auracast.py | 445 +++++++++++++++++++++++++++++++++++------- bumble/core.py | 5 +- bumble/device.py | 376 ++++++++++++++++++++++++++++++++++- bumble/hci.py | 68 ++++++- bumble/host.py | 53 ++++- 6 files changed, 875 insertions(+), 79 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 777c47b4..e0ff04e1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,9 +14,12 @@ "ASHA", "asyncio", "ATRAC", + "auracast", "avctp", "avdtp", "avrcp", + "biginfo", + "bigs", "bitpool", "bitstruct", "BSCP", @@ -36,6 +39,7 @@ "deregistration", "dhkey", "diversifier", + "ediv", "endianness", "ESCO", "Fitbit", @@ -47,6 +51,7 @@ "libc", "liblc", "libusb", + "maxs", "MITM", "MSBC", "NDIS", @@ -54,8 +59,10 @@ "NONBLOCK", "NONCONN", "OXIMETER", + "PDUS", "popleft", "PRAND", + "prefs", "protobuf", "psms", "pyee", diff --git a/apps/auracast.py b/apps/auracast.py index 2b645605..64e80aea 100644 --- a/apps/auracast.py +++ b/apps/auracast.py @@ -16,25 +16,35 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations + import asyncio import contextlib import dataclasses +import functools import logging import os +import wave +import itertools from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple import click import pyee +try: + import lc3 # type: ignore # pylint: disable=E0401 +except ImportError as e: + raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e + from bumble.colors import color -import bumble.company_ids -import bumble.core +from bumble import company_ids +from bumble import core +from bumble import gatt +from bumble import hci +from bumble.profiles import bap +from bumble.profiles import le_audio +from bumble.profiles import pbp +from bumble.profiles import bass import bumble.device -import bumble.gatt -import bumble.hci -import bumble.profiles.bap -import bumble.profiles.bass -import bumble.profiles.pbp import bumble.transport import bumble.utils @@ -49,7 +59,7 @@ logger = logging.getLogger(__name__) # Constants # ----------------------------------------------------------------------------- AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast' -AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5') +AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5') AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0 AURACAST_DEFAULT_ATT_MTU = 256 @@ -62,17 +72,12 @@ class BroadcastScanner(pyee.EventEmitter): class Broadcast(pyee.EventEmitter): name: str | None sync: bumble.device.PeriodicAdvertisingSync + broadcast_id: int rssi: int = 0 - public_broadcast_announcement: Optional[ - bumble.profiles.pbp.PublicBroadcastAnnouncement - ] = None - broadcast_audio_announcement: Optional[ - bumble.profiles.bap.BroadcastAudioAnnouncement - ] = None - basic_audio_announcement: Optional[ - bumble.profiles.bap.BasicAudioAnnouncement - ] = None - appearance: Optional[bumble.core.Appearance] = None + public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None + broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None + basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None + appearance: Optional[core.Appearance] = None biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None manufacturer_data: Optional[Tuple[str, bytes]] = None @@ -86,42 +91,36 @@ class BroadcastScanner(pyee.EventEmitter): def update(self, advertisement: bumble.device.Advertisement) -> None: self.rssi = advertisement.rssi for service_data in advertisement.data.get_all( - bumble.core.AdvertisingData.SERVICE_DATA + core.AdvertisingData.SERVICE_DATA ): assert isinstance(service_data, tuple) service_uuid, data = service_data assert isinstance(data, bytes) - if ( - service_uuid - == bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE - ): + if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE: self.public_broadcast_announcement = ( - bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data) + pbp.PublicBroadcastAnnouncement.from_bytes(data) ) continue - if ( - service_uuid - == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE - ): + if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE: self.broadcast_audio_announcement = ( - bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data) + bap.BroadcastAudioAnnouncement.from_bytes(data) ) continue self.appearance = advertisement.data.get( # type: ignore[assignment] - bumble.core.AdvertisingData.APPEARANCE + core.AdvertisingData.APPEARANCE ) if manufacturer_data := advertisement.data.get( - bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA + core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA ): assert isinstance(manufacturer_data, tuple) company_id = cast(int, manufacturer_data[0]) data = cast(bytes, manufacturer_data[1]) self.manufacturer_data = ( - bumble.company_ids.COMPANY_IDENTIFIERS.get( + company_ids.COMPANY_IDENTIFIERS.get( company_id, f'0x{company_id:04X}' ), data, @@ -232,15 +231,15 @@ class BroadcastScanner(pyee.EventEmitter): return for service_data in advertisement.data.get_all( - bumble.core.AdvertisingData.SERVICE_DATA + core.AdvertisingData.SERVICE_DATA ): assert isinstance(service_data, tuple) service_uuid, data = service_data assert isinstance(data, bytes) - if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE: + if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE: self.basic_audio_announcement = ( - bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data) + bap.BasicAudioAnnouncement.from_bytes(data) ) break @@ -262,7 +261,7 @@ class BroadcastScanner(pyee.EventEmitter): self.device = device self.filter_duplicates = filter_duplicates self.sync_timeout = sync_timeout - self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {} + self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]() device.on('advertisement', self.on_advertisement) async def start(self) -> None: @@ -277,33 +276,44 @@ class BroadcastScanner(pyee.EventEmitter): def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: if not ( ads := advertisement.data.get_all( - bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID + core.AdvertisingData.SERVICE_DATA_16_BIT_UUID ) ) or not ( - any( - ad - for ad in ads - if isinstance(ad, tuple) - and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE + broadcast_audio_announcement := next( + ( + ad + for ad in ads + if isinstance(ad, tuple) + and ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE + ), + None, ) ): return - broadcast_name = advertisement.data.get( - bumble.core.AdvertisingData.BROADCAST_NAME - ) + broadcast_name = advertisement.data.get(core.AdvertisingData.BROADCAST_NAME) assert isinstance(broadcast_name, str) or broadcast_name is None + assert isinstance(broadcast_audio_announcement[1], bytes) if broadcast := self.broadcasts.get(advertisement.address): broadcast.update(advertisement) return bumble.utils.AsyncRunner.spawn( - self.on_new_broadcast(broadcast_name, advertisement) + self.on_new_broadcast( + broadcast_name, + advertisement, + bap.BroadcastAudioAnnouncement.from_bytes( + broadcast_audio_announcement[1] + ).broadcast_id, + ) ) async def on_new_broadcast( - self, name: str | None, advertisement: bumble.device.Advertisement + self, + name: str | None, + advertisement: bumble.device.Advertisement, + broadcast_id: int, ) -> None: periodic_advertising_sync = await self.device.create_periodic_advertising_sync( advertiser_address=advertisement.address, @@ -311,7 +321,7 @@ class BroadcastScanner(pyee.EventEmitter): sync_timeout=self.sync_timeout, filter_duplicates=self.filter_duplicates, ) - broadcast = self.Broadcast(name, periodic_advertising_sync) + broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id) broadcast.update(advertisement) self.broadcasts[advertisement.address] = broadcast periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast)) @@ -323,10 +333,11 @@ class BroadcastScanner(pyee.EventEmitter): self.emit('broadcast_loss', broadcast) -class PrintingBroadcastScanner: +class PrintingBroadcastScanner(pyee.EventEmitter): def __init__( self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float ) -> None: + super().__init__() self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout) self.scanner.on('new_broadcast', self.on_new_broadcast) self.scanner.on('broadcast_loss', self.on_broadcast_loss) @@ -461,24 +472,26 @@ async def run_assist( await peer.request_mtu(mtu) # Get the BASS service - bass = await peer.discover_service_and_create_proxy( - bumble.profiles.bass.BroadcastAudioScanServiceProxy + bass_client = await peer.discover_service_and_create_proxy( + bass.BroadcastAudioScanServiceProxy ) # Check that the service was found - if not bass: + if not bass_client: print(color('!!! Broadcast Audio Scan Service not found', 'red')) return # Subscribe to and read the broadcast receive state characteristics - for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states): + for i, broadcast_receive_state in enumerate( + bass_client.broadcast_receive_states + ): try: await broadcast_receive_state.subscribe( lambda value, i=i: print( f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}" ) ) - except bumble.core.ProtocolError as error: + except core.ProtocolError as error: print( color( f'!!! Failed to subscribe to Broadcast Receive State characteristic:', @@ -497,7 +510,7 @@ async def run_assist( if command == 'add-source': # Find the requested broadcast - await bass.remote_scan_started() + await bass_client.remote_scan_started() if broadcast_name: print(color('Scanning for broadcast:', 'cyan'), broadcast_name) else: @@ -517,15 +530,15 @@ async def run_assist( # Add the source print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address) - await bass.add_source( + await bass_client.add_source( broadcast.sync.advertiser_address, broadcast.sync.sid, broadcast.broadcast_audio_announcement.broadcast_id, - bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE, 0xFFFF, [ - bumble.profiles.bass.SubgroupInfo( - bumble.profiles.bass.SubgroupInfo.ANY_BIS, + bass.SubgroupInfo( + bass.SubgroupInfo.ANY_BIS, bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), ) ], @@ -535,7 +548,7 @@ async def run_assist( await broadcast.sync.transfer(peer.connection) # Notify the sink that we're done scanning. - await bass.remote_scan_stopped() + await bass_client.remote_scan_stopped() await peer.sustain() return @@ -546,7 +559,7 @@ async def run_assist( return # Find the requested broadcast - await bass.remote_scan_started() + await bass_client.remote_scan_started() if broadcast_name: print(color('Scanning for broadcast:', 'cyan'), broadcast_name) else: @@ -569,13 +582,13 @@ async def run_assist( color('Modifying source:', 'blue'), source_id, ) - await bass.modify_source( + await bass_client.modify_source( source_id, - bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, 0xFFFF, [ - bumble.profiles.bass.SubgroupInfo( - bumble.profiles.bass.SubgroupInfo.ANY_BIS, + bass.SubgroupInfo( + bass.SubgroupInfo.ANY_BIS, bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), ) ], @@ -590,7 +603,7 @@ async def run_assist( # Remove the source print(color('Removing source:', 'blue'), source_id) - await bass.remove_source(source_id) + await bass_client.remove_source(source_id) await peer.sustain() return @@ -610,14 +623,244 @@ async def run_pair(transport: str, address: str) -> None: print("+++ Paired") +async def run_receive( + transport: str, + broadcast_id: int, + broadcast_code: str | None, + sync_timeout: float, + subgroup_index: int, +) -> None: + async with create_device(transport) as device: + if not device.supports_le_periodic_advertising: + print(color('Periodic advertising not supported', 'red')) + return + + scanner = BroadcastScanner(device, False, sync_timeout) + scan_result: asyncio.Future[BroadcastScanner.Broadcast] = ( + asyncio.get_running_loop().create_future() + ) + + def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None: + if scan_result.done(): + return + if broadcast.broadcast_id == broadcast_id: + scan_result.set_result(broadcast) + + scanner.on('new_broadcast', on_new_broadcast) + await scanner.start() + print('Start scanning...') + broadcast = await scan_result + print('Advertisement found:') + broadcast.print() + basic_audio_announcement_scanned = asyncio.Event() + + def on_change() -> None: + if ( + broadcast.basic_audio_announcement + and not basic_audio_announcement_scanned.is_set() + ): + basic_audio_announcement_scanned.set() + + broadcast.on('change', on_change) + if not broadcast.basic_audio_announcement: + print('Wait for Basic Audio Announcement...') + await basic_audio_announcement_scanned.wait() + print('Basic Audio Announcement found') + broadcast.print() + print('Stop scanning') + await scanner.stop() + print('Start sync to BIG') + + assert broadcast.basic_audio_announcement + subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index] + configuration = subgroup.codec_specific_configuration + assert configuration + assert (sampling_frequency := configuration.sampling_frequency) + assert (frame_duration := configuration.frame_duration) + + big_sync = await device.create_big_sync( + broadcast.sync, + bumble.device.BigSyncParameters( + big_sync_timeout=0x4000, + bis=[bis.index for bis in subgroup.bis], + broadcast_code=( + bytes.fromhex(broadcast_code) if broadcast_code else None + ), + ), + ) + num_bis = len(big_sync.bis_links) + decoder = lc3.Decoder( + frame_duration_us=frame_duration.us, + sample_rate_hz=sampling_frequency.hz, + num_channels=num_bis, + ) + sdus = [b''] * num_bis + subprocess = await asyncio.create_subprocess_shell( + f'stdbuf -i0 ffplay -ar 48000 -ac {num_bis} -f f32le pipe:0', + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + for i, bis_link in enumerate(big_sync.bis_links): + print(f'Setup ISO for BIS {bis_link.handle}') + + def sink(index: int, packet: hci.HCI_IsoDataPacket): + nonlocal sdus + sdus[index] = packet.iso_sdu_fragment + if all(sdus) and subprocess.stdin: + subprocess.stdin.write(decoder.decode(b''.join(sdus)).tobytes()) + sdus = [b''] * num_bis + + bis_link.sink = functools.partial(sink, i) + await device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=bis_link.handle, + data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST, + data_path_id=0, + codec_id=hci.CodingFormat(codec_id=hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ), + check_result=True, + ) + + terminated = asyncio.Event() + big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set()) + await terminated.wait() + + +async def run_broadcast( + transport: str, broadcast_id: int, broadcast_code: str | None, wav_file_path: str +) -> None: + async with create_device(transport) as device: + if not device.supports_le_periodic_advertising: + print(color('Periodic advertising not supported', 'red')) + return + + with wave.open(wav_file_path, 'rb') as wav: + print('Encoding wav file into lc3...') + encoder = lc3.Encoder( + frame_duration_us=10000, + sample_rate_hz=48000, + num_channels=2, + input_sample_rate_hz=wav.getframerate(), + ) + frames = list[bytes]() + while pcm := wav.readframes(encoder.get_frame_samples()): + frames.append( + encoder.encode(pcm, num_bytes=200, bit_depth=wav.getsampwidth() * 8) + ) + del encoder + print('Encoding complete.') + + basic_audio_announcement = bap.BasicAudioAnnouncement( + presentation_delay=40000, + subgroups=[ + bap.BasicAudioAnnouncement.Subgroup( + codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3), + codec_specific_configuration=bap.CodecSpecificConfiguration( + sampling_frequency=bap.SamplingFrequency.FREQ_48000, + frame_duration=bap.FrameDuration.DURATION_10000_US, + octets_per_codec_frame=100, + ), + metadata=le_audio.Metadata( + [ + le_audio.Metadata.Entry( + tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng' + ), + le_audio.Metadata.Entry( + tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco' + ), + ] + ), + bis=[ + bap.BasicAudioAnnouncement.BIS( + index=1, + codec_specific_configuration=bap.CodecSpecificConfiguration( + audio_channel_allocation=bap.AudioLocation.FRONT_LEFT + ), + ), + bap.BasicAudioAnnouncement.BIS( + index=2, + codec_specific_configuration=bap.CodecSpecificConfiguration( + audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT + ), + ), + ], + ) + ], + ) + broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id) + print('Start Advertising') + advertising_set = await device.create_advertising_set( + advertising_parameters=bumble.device.AdvertisingParameters( + advertising_event_properties=bumble.device.AdvertisingEventProperties( + is_connectable=False + ), + primary_advertising_interval_min=100, + primary_advertising_interval_max=200, + ), + advertising_data=( + broadcast_audio_announcement.get_advertising_data() + + bytes( + core.AdvertisingData( + [(core.AdvertisingData.BROADCAST_NAME, b'Bumble Auracast')] + ) + ) + ), + periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters( + periodic_advertising_interval_min=80, + periodic_advertising_interval_max=160, + ), + periodic_advertising_data=basic_audio_announcement.get_advertising_data(), + auto_restart=True, + auto_start=True, + ) + print('Start Periodic Advertising') + await advertising_set.start_periodic() + print('Setup BIG') + big = await device.create_big( + advertising_set, + parameters=bumble.device.BigParameters( + num_bis=2, + sdu_interval=10000, + max_sdu=100, + max_transport_latency=65, + rtn=4, + broadcast_code=( + bytes.fromhex(broadcast_code) if broadcast_code else None + ), + ), + ) + print('Setup ISO Data Path') + for bis_link in big.bis_links: + await device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=bis_link.handle, + data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER, + data_path_id=0, + codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ), + check_result=True, + ) + + for frame in itertools.cycle(frames): + mid = len(frame) // 2 + big.bis_links[0].write(frame[:mid]) + big.bis_links[1].write(frame[mid:]) + await asyncio.sleep(0.009) + + def run_async(async_command: Coroutine) -> None: try: asyncio.run(async_command) - except bumble.core.ProtocolError as error: + except core.ProtocolError as error: if error.error_namespace == 'att' and error.error_code in list( - bumble.profiles.bass.ApplicationError + bass.ApplicationError ): - message = bumble.profiles.bass.ApplicationError(error.error_code).name + message = bass.ApplicationError(error.error_code).name else: message = str(error) @@ -631,9 +874,7 @@ def run_async(async_command: Coroutine) -> None: # ----------------------------------------------------------------------------- @click.group() @click.pass_context -def auracast( - ctx, -): +def auracast(ctx): ctx.ensure_object(dict) @@ -691,6 +932,66 @@ def pair(ctx, transport, address): run_async(run_pair(transport, address)) +@auracast.command('receive') +@click.argument('transport') +@click.argument('broadcast_id', type=int) +@click.option( + '--broadcast-code', + metavar='BROADCAST_CODE', + type=str, + help='Broadcast encryption code in hex format', +) +@click.option( + '--sync-timeout', + metavar='SYNC_TIMEOUT', + type=float, + default=AURACAST_DEFAULT_SYNC_TIMEOUT, + help='Sync timeout (in seconds)', +) +@click.option( + '--subgroup', + metavar='SUBGROUP', + type=int, + default=0, + help='Index of Subgroup', +) +@click.pass_context +def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout, subgroup): + """Receive a broadcast source""" + run_async( + run_receive(transport, broadcast_id, broadcast_code, sync_timeout, subgroup) + ) + + +@auracast.command('broadcast') +@click.argument('transport') +@click.argument('wav_file_path', type=str) +@click.option( + '--broadcast-id', + metavar='BROADCAST_ID', + type=int, + default=123456, + help='Broadcast ID', +) +@click.option( + '--broadcast-code', + metavar='BROADCAST_CODE', + type=str, + help='Broadcast encryption code in hex format', +) +@click.pass_context +def broadcast(ctx, transport, broadcast_id, broadcast_code, wav_file_path): + """Start a broadcast as a source.""" + run_async( + run_broadcast( + transport=transport, + broadcast_id=broadcast_id, + broadcast_code=broadcast_code, + wav_file_path=wav_file_path, + ) + ) + + def main(): logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) auracast() diff --git a/bumble/core.py b/bumble/core.py index f6d42dd5..0161a73c 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -1501,7 +1501,10 @@ class AdvertisingData: ad_data_str = f'"{ad_data.decode("utf-8")}"' elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME: ad_type_str = 'Complete Local Name' - ad_data_str = f'"{ad_data.decode("utf-8")}"' + try: + ad_data_str = f'"{ad_data.decode("utf-8")}"' + except UnicodeDecodeError: + ad_data_str = ad_data.hex() elif ad_type == AdvertisingData.TX_POWER_LEVEL: ad_type_str = 'TX Power Level' ad_data_str = str(ad_data[0]) diff --git a/bumble/device.py b/bumble/device.py index 866ef166..9cb5748c 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -17,7 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from contextlib import ( asynccontextmanager, AsyncExitStack, @@ -119,6 +119,8 @@ DEVICE_MIN_LE_RSSI = -127 DEVICE_MAX_LE_RSSI = 20 DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE = 0x00 DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE = 0xEF +DEVICE_MIN_BIG_HANDLE = 0x00 +DEVICE_MAX_BIG_HANDLE = 0xEF DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00' DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms @@ -992,6 +994,130 @@ class PeriodicAdvertisingSync(EventEmitter): ) +# ----------------------------------------------------------------------------- +@dataclass +class BigParameters: + num_bis: int + sdu_interval: int + max_sdu: int + max_transport_latency: int + rtn: int + phy: hci.PhyBit = hci.PhyBit.LE_2M + packing: int = 0 + framing: int = 0 + broadcast_code: bytes | None = None + + +# ----------------------------------------------------------------------------- +@dataclass +class Big(EventEmitter): + class State(IntEnum): + PENDING = 0 + ACTIVE = 1 + TERMINATED = 2 + + class Event(str, Enum): + ESTABLISHMENT = 'establishment' + ESTABLISHMENT_FAILURE = 'establishment_failure' + TERMINATION = 'termination' + + big_handle: int + advertising_set: AdvertisingSet + parameters: BigParameters + state: State = State.PENDING + + # Attributes provided by BIG Create Complete event + big_sync_delay: int = 0 + transport_latency_big: int = 0 + phy: int = 0 + nse: int = 0 + bn: int = 0 + pto: int = 0 + irc: int = 0 + max_pdu: int = 0 + iso_interval: int = 0 + bis_links: Sequence[BisLink] = () + + def __post_init__(self) -> None: + super().__init__() + self.device = self.advertising_set.device + + async def terminate( + self, + reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, + ) -> None: + if self.state != Big.State.ACTIVE: + logger.error('BIG %d is not active.', self.big_handle) + return + + with closing(EventWatcher()) as watcher: + terminated = asyncio.Event() + watcher.once(self, Big.Event.TERMINATION, lambda _: terminated.set()) + await self.device.send_command( + hci.HCI_LE_Terminate_BIG_Command( + big_handle=self.big_handle, reason=reason + ), + check_result=True, + ) + await terminated.wait() + + +# ----------------------------------------------------------------------------- +@dataclass +class BigSyncParameters: + big_sync_timeout: int + bis: Sequence[int] + mse: int = 0 + broadcast_code: bytes | None = None + + +# ----------------------------------------------------------------------------- +@dataclass +class BigSync(EventEmitter): + class State(IntEnum): + PENDING = 0 + ACTIVE = 1 + TERMINATED = 2 + + class Event(str, Enum): + ESTABLISHMENT = 'establishment' + ESTABLISHMENT_FAILURE = 'establishment_failure' + TERMINATION = 'termination' + + big_handle: int + pa_sync: PeriodicAdvertisingSync + parameters: BigSyncParameters + state: State = State.PENDING + + # Attributes provided by BIG Create Sync Complete event + transport_latency_big: int = 0 + nse: int = 0 + bn: int = 0 + pto: int = 0 + irc: int = 0 + max_pdu: int = 0 + iso_interval: int = 0 + bis_links: Sequence[BisLink] = () + + def __post_init__(self) -> None: + super().__init__() + self.device = self.pa_sync.device + + async def terminate(self) -> None: + if self.state != BigSync.State.ACTIVE: + logger.error('BIG Sync %d is not active.', self.big_handle) + return + + with closing(EventWatcher()) as watcher: + terminated = asyncio.Event() + watcher.once(self, BigSync.Event.TERMINATION, lambda _: terminated.set()) + await self.device.send_command( + hci.HCI_LE_BIG_Terminate_Sync_Command(big_handle=self.big_handle), + check_result=True, + ) + await terminated.wait() + + # ----------------------------------------------------------------------------- class LePhyOptions: # Coded PHY preference @@ -1225,6 +1351,32 @@ class CisLink(CompositeEventEmitter): await self.device.disconnect(self, reason) +# ----------------------------------------------------------------------------- +@dataclass +class BisLink: + handle: int + big: Big | BigSync + sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None + + def __post_init__(self) -> None: + self.device = self.big.device + self.packet_sequence_number = 0 + + def write(self, sdu: bytes) -> None: + self.device.host.send_hci_packet( + hci.HCI_IsoDataPacket( + connection_handle=self.handle, + data_total_length=len(sdu) + 4, + packet_sequence_number=self.packet_sequence_number, + pb_flag=0b10, + packet_status_flag=0, + iso_sdu_length=len(sdu), + iso_sdu_fragment=sdu, + ) + ) + self.packet_sequence_number += 1 + + # ----------------------------------------------------------------------------- class Connection(CompositeEventEmitter): device: Device @@ -1713,6 +1865,9 @@ class Device(CompositeEventEmitter): legacy_advertiser: Optional[LegacyAdvertiser] sco_links: Dict[int, ScoLink] cis_links: Dict[int, CisLink] + bigs = dict[int, Big]() + bis_links = dict[int, BisLink]() + big_syncs = dict[int, BigSync]() _pending_cis: Dict[int, Tuple[int, int]] @composite_listener @@ -2009,6 +2164,17 @@ class Device(CompositeEventEmitter): None, ) + def next_big_handle(self) -> int | None: + return next( + ( + handle + for handle in range(DEVICE_MIN_BIG_HANDLE, DEVICE_MAX_BIG_HANDLE + 1) + if handle + not in itertools.chain(self.bigs.keys(), self.big_syncs.keys()) + ), + None, + ) + @deprecated("Please use create_l2cap_server()") def register_l2cap_server(self, psm, server) -> int: return self.l2cap_channel_manager.register_server(psm, server) @@ -4112,6 +4278,106 @@ class Device(CompositeEventEmitter): check_result=True, ) + # [LE only] + @experimental('Only for testing.') + async def create_big( + self, advertising_set: AdvertisingSet, parameters: BigParameters + ) -> Big: + if (big_handle := self.next_big_handle()) is None: + raise core.OutOfResourcesError("All valid BIG handles already in use") + + with closing(EventWatcher()) as watcher: + big = Big( + big_handle=big_handle, + parameters=parameters, + advertising_set=advertising_set, + ) + self.bigs[big_handle] = big + established = asyncio.get_running_loop().create_future() + watcher.once( + big, big.Event.ESTABLISHMENT, lambda: established.set_result(None) + ) + watcher.once( + big, + big.Event.ESTABLISHMENT_FAILURE, + lambda status: established.set_exception(hci.HCI_Error(status)), + ) + + try: + await self.send_command( + hci.HCI_LE_Create_BIG_Command( + big_handle=big_handle, + advertising_handle=advertising_set.advertising_handle, + num_bis=parameters.num_bis, + sdu_interval=parameters.sdu_interval, + max_sdu=parameters.max_sdu, + max_transport_latency=parameters.max_transport_latency, + rtn=parameters.rtn, + phy=parameters.phy, + packing=parameters.packing, + framing=parameters.framing, + encryption=1 if parameters.broadcast_code else 0, + broadcast_code=parameters.broadcast_code or bytes(16), + ), + check_result=True, + ) + await established + except hci.HCI_Error: + del self.bigs[big_handle] + raise + + return big + + # [LE only] + @experimental('Only for testing.') + async def create_big_sync( + self, pa_sync: PeriodicAdvertisingSync, parameters: BigSyncParameters + ) -> BigSync: + if (big_handle := self.next_big_handle()) is None: + raise core.OutOfResourcesError("All valid BIG handles already in use") + + if (pa_sync_handle := pa_sync.sync_handle) is None: + raise core.InvalidStateError("PA Sync is not established") + + with closing(EventWatcher()) as watcher: + big_sync = BigSync( + big_handle=big_handle, + parameters=parameters, + pa_sync=pa_sync, + ) + self.big_syncs[big_handle] = big_sync + established = asyncio.get_running_loop().create_future() + watcher.once( + big_sync, + big_sync.Event.ESTABLISHMENT, + lambda: established.set_result(None), + ) + watcher.once( + big_sync, + big_sync.Event.ESTABLISHMENT_FAILURE, + lambda status: established.set_exception(hci.HCI_Error(status)), + ) + + try: + await self.send_command( + hci.HCI_LE_BIG_Create_Sync_Command( + big_handle=big_handle, + sync_handle=pa_sync_handle, + encryption=1 if parameters.broadcast_code else 0, + broadcast_code=parameters.broadcast_code or bytes(16), + mse=parameters.mse, + big_sync_timeout=parameters.big_sync_timeout, + bis=parameters.bis, + ), + check_result=True, + ) + await established + except hci.HCI_Error: + del self.big_syncs[big_handle] + raise + + return big_sync + async def get_remote_le_features(self, connection: Connection) -> hci.LeFeatureMask: """[LE Only] Reads remote LE supported features. @@ -4233,6 +4499,112 @@ class Device(CompositeEventEmitter): ) self.connecting_extended_advertising_sets[connection_handle] = advertising_set + @host_event_handler + def on_big_establishment( + self, + status: int, + big_handle: int, + bis_handles: List[int], + big_sync_delay: int, + transport_latency_big: int, + phy: int, + nse: int, + bn: int, + pto: int, + irc: int, + max_pdu: int, + iso_interval: int, + ) -> None: + if not (big := self.bigs.get(big_handle)): + logger.warning('BIG %d not found', big_handle) + return + + if status != hci.HCI_SUCCESS: + del self.bigs[big_handle] + logger.debug('Unable to create BIG %d', big_handle) + big.state = Big.State.TERMINATED + big.emit(Big.Event.ESTABLISHMENT_FAILURE, status) + return + + big.bis_links = [BisLink(handle=handle, big=big) for handle in bis_handles] + big.big_sync_delay = big_sync_delay + big.transport_latency_big = transport_latency_big + big.phy = phy + big.nse = nse + big.bn = bn + big.pto = pto + big.irc = irc + big.max_pdu = max_pdu + big.iso_interval = iso_interval + big.state = Big.State.ACTIVE + + for bis_link in big.bis_links: + self.bis_links[bis_link.handle] = bis_link + big.emit(Big.Event.ESTABLISHMENT) + + @host_event_handler + def on_big_termination(self, reason: int, big_handle: int) -> None: + if not (big := self.bigs.pop(big_handle, None)): + logger.warning('BIG %d not found', big_handle) + return + + big.state = Big.State.TERMINATED + for bis_link in big.bis_links: + self.bis_links.pop(bis_link.handle, None) + big.emit(Big.Event.TERMINATION, reason) + + @host_event_handler + def on_big_sync_establishment( + self, + status: int, + big_handle: int, + transport_latency_big: int, + nse: int, + bn: int, + pto: int, + irc: int, + max_pdu: int, + iso_interval: int, + bis_handles: list[int], + ) -> None: + if not (big_sync := self.big_syncs.get(big_handle)): + logger.warning('BIG Sync %d not found', big_handle) + return + + if status != hci.HCI_SUCCESS: + del self.big_syncs[big_handle] + logger.debug('Unable to create BIG Sync %d', big_handle) + big_sync.state = BigSync.State.TERMINATED + big_sync.emit(BigSync.Event.ESTABLISHMENT_FAILURE, status) + return + + big_sync.transport_latency_big = transport_latency_big + big_sync.nse = nse + big_sync.bn = bn + big_sync.pto = pto + big_sync.irc = irc + big_sync.max_pdu = max_pdu + big_sync.iso_interval = iso_interval + big_sync.bis_links = [ + BisLink(handle=handle, big=big_sync) for handle in bis_handles + ] + big_sync.state = BigSync.State.ACTIVE + + for bis_link in big_sync.bis_links: + self.bis_links[bis_link.handle] = bis_link + big_sync.emit(BigSync.Event.ESTABLISHMENT) + + @host_event_handler + def on_big_sync_lost(self, big_handle: int, reason: int) -> None: + if not (big_sync := self.big_syncs.pop(big_handle, None)): + logger.warning('BIG %d not found', big_handle) + return + + for bis_link in big_sync.bis_links: + self.bis_links.pop(bis_link.handle, None) + big_sync.state = BigSync.State.TERMINATED + big_sync.emit(BigSync.Event.TERMINATION, reason) + def _complete_le_extended_advertising_connection( self, connection: Connection, advertising_set: AdvertisingSet ) -> None: @@ -4879,6 +5251,8 @@ class Device(CompositeEventEmitter): def on_iso_packet(self, handle: int, packet: hci.HCI_IsoDataPacket) -> None: if (cis_link := self.cis_links.get(handle)) and cis_link.sink: cis_link.sink(packet) + elif (bis_link := self.bis_links.get(handle)) and bis_link.sink: + bis_link.sink(packet) @host_event_handler @with_connection_from_handle diff --git a/bumble/hci.py b/bumble/hci.py index 24f91fa4..92000488 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4934,7 +4934,7 @@ class HCI_LE_Create_BIG_Command(HCI_Command): packing: int framing: int encryption: int - broadcast_code: int + broadcast_code: bytes # ----------------------------------------------------------------------------- @@ -5823,7 +5823,7 @@ class HCI_LE_Periodic_Advertising_Sync_Lost_Event(HCI_LE_Meta_Event): # ----------------------------------------------------------------------------- @HCI_LE_Meta_Event.event( [ - ('status', 1), + ('status', STATUS_SPEC), ('advertising_handle', 1), ('connection_handle', 2), ('num_completed_extended_advertising_events', 1), @@ -5906,6 +5906,70 @@ class HCI_LE_CIS_Request_Event(HCI_LE_Meta_Event): ''' +# ----------------------------------------------------------------------------- +@HCI_LE_Meta_Event.event( + [ + ('status', STATUS_SPEC), + ('big_handle', 1), + ('big_sync_delay', 3), + ('transport_latency_big', 3), + ('phy', 1), + ('nse', 1), + ('bn', 1), + ('pto', 1), + ('irc', 1), + ('max_pdu', 2), + ('iso_interval', 2), + [('connection_handle', 2)], + ] +) +class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.27 LE Create BIG Complete Event + ''' + + +# ----------------------------------------------------------------------------- +@HCI_LE_Meta_Event.event([('big_handle', 1), ('reason', 1)]) +class HCI_LE_Terminate_BIG_Complete_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.28 LE Terminate BIG Complete Event + ''' + + +# ----------------------------------------------------------------------------- + + +@HCI_LE_Meta_Event.event( + [ + ('status', STATUS_SPEC), + ('big_handle', 1), + ('transport_latency_big', 3), + ('nse', 1), + ('bn', 1), + ('pto', 1), + ('irc', 1), + ('max_pdu', 2), + ('iso_interval', 2), + [('connection_handle', 2)], + ] +) +class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.29 LE BIG Sync Established event + ''' + + +# ----------------------------------------------------------------------------- + + +@HCI_LE_Meta_Event.event([('big_handle', 1), ('reason', 1)]) +class HCI_LE_BIG_Sync_Lost_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.30 LE BIG Sync Lost event + ''' + + # ----------------------------------------------------------------------------- @HCI_LE_Meta_Event.event( [ diff --git a/bumble/host.py b/bumble/host.py index 57d05fa6..1ce4263a 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,6 +21,7 @@ import collections import dataclasses import logging import struct +import itertools from typing import ( Any, @@ -149,6 +150,7 @@ class Host(AbortableEventEmitter): connections: Dict[int, Connection] cis_links: Dict[int, CisLink] sco_links: Dict[int, ScoLink] + bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles acl_packet_queue: Optional[AclPacketQueue] = None le_acl_packet_queue: Optional[AclPacketQueue] = None hci_sink: Optional[TransportSink] = None @@ -733,9 +735,10 @@ class Host(AbortableEventEmitter): ): if connection := self.connections.get(connection_handle): connection.acl_packet_queue.on_packets_completed(num_completed_packets) - elif not ( - self.cis_links.get(connection_handle) - or self.sco_links.get(connection_handle) + elif connection_handle not in itertools.chain( + self.cis_links.keys(), + self.sco_links.keys(), + itertools.chain.from_iterable(self.bigs.values()), ): logger.warning( 'received packet completion event for unknown handle ' @@ -953,6 +956,50 @@ class Host(AbortableEventEmitter): event.cis_id, ) + def on_hci_le_create_big_complete_event(self, event): + self.bigs[event.big_handle] = set(event.connection_handle) + self.emit( + 'big_establishment', + event.status, + event.big_handle, + event.connection_handle, + event.big_sync_delay, + event.transport_latency_big, + event.phy, + event.nse, + event.bn, + event.pto, + event.irc, + event.max_pdu, + event.iso_interval, + ) + + def on_hci_le_big_sync_established_event(self, event): + self.emit( + 'big_sync_establishment', + event.status, + event.big_handle, + event.transport_latency_big, + event.nse, + event.bn, + event.pto, + event.irc, + event.max_pdu, + event.iso_interval, + event.connection_handle, + ) + + def on_hci_le_big_sync_lost_event(self, event): + self.emit( + 'big_sync_lost', + event.big_handle, + event.reason, + ) + + def on_hci_le_terminate_big_complete_event(self, event): + self.bigs.pop(event.big_handle) + self.emit('big_termination', event.reason, event.big_handle) + def on_hci_le_cis_established_event(self, event): # The remaining parameters are unused for now. if event.status == hci.HCI_SUCCESS: From d238dd4059c803d75c42f9b5f9cf841401c7f47b Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 23 Dec 2024 17:01:11 +0800 Subject: [PATCH 2/2] Use dynamic sample rate --- apps/auracast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/auracast.py b/apps/auracast.py index 64e80aea..5d8a6bf5 100644 --- a/apps/auracast.py +++ b/apps/auracast.py @@ -696,7 +696,7 @@ async def run_receive( ) sdus = [b''] * num_bis subprocess = await asyncio.create_subprocess_shell( - f'stdbuf -i0 ffplay -ar 48000 -ac {num_bis} -f f32le pipe:0', + f'stdbuf -i0 ffplay -ar {sampling_frequency.hz} -ac {num_bis} -f f32le pipe:0', stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,