This commit is contained in:
Josh Wu
2024-11-12 17:38:17 +08:00
parent 80d60aaf15
commit 7324d322fe
6 changed files with 875 additions and 79 deletions

View File

@@ -16,25 +16,35 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import functools
import logging
import os
import wave
import itertools
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
import click
import pyee
try:
import lc3 # type: ignore # pylint: disable=E0401
except ImportError as e:
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
from bumble.colors import color
import bumble.company_ids
import bumble.core
from bumble import company_ids
from bumble import core
from bumble import gatt
from bumble import hci
from bumble.profiles import bap
from bumble.profiles import le_audio
from bumble.profiles import pbp
from bumble.profiles import bass
import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
@@ -49,7 +59,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256
@@ -62,17 +72,12 @@ class BroadcastScanner(pyee.EventEmitter):
class Broadcast(pyee.EventEmitter):
name: str | None
sync: bumble.device.PeriodicAdvertisingSync
broadcast_id: int
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
] = None
broadcast_audio_announcement: Optional[
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None
broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None
basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None
appearance: Optional[core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None
@@ -86,42 +91,36 @@ class BroadcastScanner(pyee.EventEmitter):
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if (
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE:
self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
pbp.PublicBroadcastAnnouncement.from_bytes(data)
)
continue
if (
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE:
self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
bap.BroadcastAudioAnnouncement.from_bytes(data)
)
continue
self.appearance = advertisement.data.get( # type: ignore[assignment]
bumble.core.AdvertisingData.APPEARANCE
core.AdvertisingData.APPEARANCE
)
if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
):
assert isinstance(manufacturer_data, tuple)
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get(
company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
),
data,
@@ -232,15 +231,15 @@ class BroadcastScanner(pyee.EventEmitter):
return
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
bap.BasicAudioAnnouncement.from_bytes(data)
)
break
@@ -262,7 +261,7 @@ class BroadcastScanner(pyee.EventEmitter):
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]()
device.on('advertisement', self.on_advertisement)
async def start(self) -> None:
@@ -277,33 +276,44 @@ class BroadcastScanner(pyee.EventEmitter):
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if not (
ads := advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
)
) or not (
any(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
broadcast_audio_announcement := next(
(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
),
None,
)
):
return
broadcast_name = advertisement.data.get(
bumble.core.AdvertisingData.BROADCAST_NAME
)
broadcast_name = advertisement.data.get(core.AdvertisingData.BROADCAST_NAME)
assert isinstance(broadcast_name, str) or broadcast_name is None
assert isinstance(broadcast_audio_announcement[1], bytes)
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
return
bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
self.on_new_broadcast(
broadcast_name,
advertisement,
bap.BroadcastAudioAnnouncement.from_bytes(
broadcast_audio_announcement[1]
).broadcast_id,
)
)
async def on_new_broadcast(
self, name: str | None, advertisement: bumble.device.Advertisement
self,
name: str | None,
advertisement: bumble.device.Advertisement,
broadcast_id: int,
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
@@ -311,7 +321,7 @@ class BroadcastScanner(pyee.EventEmitter):
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(name, periodic_advertising_sync)
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
@@ -323,10 +333,11 @@ class BroadcastScanner(pyee.EventEmitter):
self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner:
class PrintingBroadcastScanner(pyee.EventEmitter):
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
super().__init__()
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
@@ -461,24 +472,26 @@ async def run_assist(
await peer.request_mtu(mtu)
# Get the BASS service
bass = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy
bass_client = await peer.discover_service_and_create_proxy(
bass.BroadcastAudioScanServiceProxy
)
# Check that the service was found
if not bass:
if not bass_client:
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return
# Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
for i, broadcast_receive_state in enumerate(
bass_client.broadcast_receive_states
):
try:
await broadcast_receive_state.subscribe(
lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
)
)
except bumble.core.ProtocolError as error:
except core.ProtocolError as error:
print(
color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
@@ -497,7 +510,7 @@ async def run_assist(
if command == 'add-source':
# Find the requested broadcast
await bass.remote_scan_started()
await bass_client.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
@@ -517,15 +530,15 @@ async def run_assist(
# Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source(
await bass_client.add_source(
broadcast.sync.advertiser_address,
broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bass.SubgroupInfo(
bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
@@ -535,7 +548,7 @@ async def run_assist(
await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning.
await bass.remote_scan_stopped()
await bass_client.remote_scan_stopped()
await peer.sustain()
return
@@ -546,7 +559,7 @@ async def run_assist(
return
# Find the requested broadcast
await bass.remote_scan_started()
await bass_client.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
@@ -569,13 +582,13 @@ async def run_assist(
color('Modifying source:', 'blue'),
source_id,
)
await bass.modify_source(
await bass_client.modify_source(
source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bass.SubgroupInfo(
bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
@@ -590,7 +603,7 @@ async def run_assist(
# Remove the source
print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id)
await bass_client.remove_source(source_id)
await peer.sustain()
return
@@ -610,14 +623,244 @@ async def run_pair(transport: str, address: str) -> None:
print("+++ Paired")
async def run_receive(
transport: str,
broadcast_id: int,
broadcast_code: str | None,
sync_timeout: float,
subgroup_index: int,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = BroadcastScanner(device, False, sync_timeout)
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
asyncio.get_running_loop().create_future()
)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if scan_result.done():
return
if broadcast.broadcast_id == broadcast_id:
scan_result.set_result(broadcast)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
print('Start scanning...')
broadcast = await scan_result
print('Advertisement found:')
broadcast.print()
basic_audio_announcement_scanned = asyncio.Event()
def on_change() -> None:
if (
broadcast.basic_audio_announcement
and not basic_audio_announcement_scanned.is_set()
):
basic_audio_announcement_scanned.set()
broadcast.on('change', on_change)
if not broadcast.basic_audio_announcement:
print('Wait for Basic Audio Announcement...')
await basic_audio_announcement_scanned.wait()
print('Basic Audio Announcement found')
broadcast.print()
print('Stop scanning')
await scanner.stop()
print('Start sync to BIG')
assert broadcast.basic_audio_announcement
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
configuration = subgroup.codec_specific_configuration
assert configuration
assert (sampling_frequency := configuration.sampling_frequency)
assert (frame_duration := configuration.frame_duration)
big_sync = await device.create_big_sync(
broadcast.sync,
bumble.device.BigSyncParameters(
big_sync_timeout=0x4000,
bis=[bis.index for bis in subgroup.bis],
broadcast_code=(
bytes.fromhex(broadcast_code) if broadcast_code else None
),
),
)
num_bis = len(big_sync.bis_links)
decoder = lc3.Decoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=num_bis,
)
sdus = [b''] * num_bis
subprocess = await asyncio.create_subprocess_shell(
f'stdbuf -i0 ffplay -ar 48000 -ac {num_bis} -f f32le pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
for i, bis_link in enumerate(big_sync.bis_links):
print(f'Setup ISO for BIS {bis_link.handle}')
def sink(index: int, packet: hci.HCI_IsoDataPacket):
nonlocal sdus
sdus[index] = packet.iso_sdu_fragment
if all(sdus) and subprocess.stdin:
subprocess.stdin.write(decoder.decode(b''.join(sdus)).tobytes())
sdus = [b''] * num_bis
bis_link.sink = functools.partial(sink, i)
await device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=bis_link.handle,
data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST,
data_path_id=0,
codec_id=hci.CodingFormat(codec_id=hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
)
terminated = asyncio.Event()
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
await terminated.wait()
async def run_broadcast(
transport: str, broadcast_id: int, broadcast_code: str | None, wav_file_path: str
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
with wave.open(wav_file_path, 'rb') as wav:
print('Encoding wav file into lc3...')
encoder = lc3.Encoder(
frame_duration_us=10000,
sample_rate_hz=48000,
num_channels=2,
input_sample_rate_hz=wav.getframerate(),
)
frames = list[bytes]()
while pcm := wav.readframes(encoder.get_frame_samples()):
frames.append(
encoder.encode(pcm, num_bytes=200, bit_depth=wav.getsampwidth() * 8)
)
del encoder
print('Encoding complete.')
basic_audio_announcement = bap.BasicAudioAnnouncement(
presentation_delay=40000,
subgroups=[
bap.BasicAudioAnnouncement.Subgroup(
codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3),
codec_specific_configuration=bap.CodecSpecificConfiguration(
sampling_frequency=bap.SamplingFrequency.FREQ_48000,
frame_duration=bap.FrameDuration.DURATION_10000_US,
octets_per_codec_frame=100,
),
metadata=le_audio.Metadata(
[
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng'
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco'
),
]
),
bis=[
bap.BasicAudioAnnouncement.BIS(
index=1,
codec_specific_configuration=bap.CodecSpecificConfiguration(
audio_channel_allocation=bap.AudioLocation.FRONT_LEFT
),
),
bap.BasicAudioAnnouncement.BIS(
index=2,
codec_specific_configuration=bap.CodecSpecificConfiguration(
audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT
),
),
],
)
],
)
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
print('Start Advertising')
advertising_set = await device.create_advertising_set(
advertising_parameters=bumble.device.AdvertisingParameters(
advertising_event_properties=bumble.device.AdvertisingEventProperties(
is_connectable=False
),
primary_advertising_interval_min=100,
primary_advertising_interval_max=200,
),
advertising_data=(
broadcast_audio_announcement.get_advertising_data()
+ bytes(
core.AdvertisingData(
[(core.AdvertisingData.BROADCAST_NAME, b'Bumble Auracast')]
)
)
),
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
periodic_advertising_interval_min=80,
periodic_advertising_interval_max=160,
),
periodic_advertising_data=basic_audio_announcement.get_advertising_data(),
auto_restart=True,
auto_start=True,
)
print('Start Periodic Advertising')
await advertising_set.start_periodic()
print('Setup BIG')
big = await device.create_big(
advertising_set,
parameters=bumble.device.BigParameters(
num_bis=2,
sdu_interval=10000,
max_sdu=100,
max_transport_latency=65,
rtn=4,
broadcast_code=(
bytes.fromhex(broadcast_code) if broadcast_code else None
),
),
)
print('Setup ISO Data Path')
for bis_link in big.bis_links:
await device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=bis_link.handle,
data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER,
data_path_id=0,
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
)
for frame in itertools.cycle(frames):
mid = len(frame) // 2
big.bis_links[0].write(frame[:mid])
big.bis_links[1].write(frame[mid:])
await asyncio.sleep(0.009)
def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
except bumble.core.ProtocolError as error:
except core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError
bass.ApplicationError
):
message = bumble.profiles.bass.ApplicationError(error.error_code).name
message = bass.ApplicationError(error.error_code).name
else:
message = str(error)
@@ -631,9 +874,7 @@ def run_async(async_command: Coroutine) -> None:
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
def auracast(ctx):
ctx.ensure_object(dict)
@@ -691,6 +932,66 @@ def pair(ctx, transport, address):
run_async(run_pair(transport, address))
@auracast.command('receive')
@click.argument('transport')
@click.argument('broadcast_id', type=int)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
type=str,
help='Broadcast encryption code in hex format',
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)',
)
@click.option(
'--subgroup',
metavar='SUBGROUP',
type=int,
default=0,
help='Index of Subgroup',
)
@click.pass_context
def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout, subgroup):
"""Receive a broadcast source"""
run_async(
run_receive(transport, broadcast_id, broadcast_code, sync_timeout, subgroup)
)
@auracast.command('broadcast')
@click.argument('transport')
@click.argument('wav_file_path', type=str)
@click.option(
'--broadcast-id',
metavar='BROADCAST_ID',
type=int,
default=123456,
help='Broadcast ID',
)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
type=str,
help='Broadcast encryption code in hex format',
)
@click.pass_context
def broadcast(ctx, transport, broadcast_id, broadcast_code, wav_file_path):
"""Start a broadcast as a source."""
run_async(
run_broadcast(
transport=transport,
broadcast_id=broadcast_id,
broadcast_code=broadcast_code,
wav_file_path=wav_file_path,
)
)
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()