forked from auracaster/bumble_mirror
Compare commits
15 Commits
gbg/blueto
...
gbg/sdp-en
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
931e2de854 | ||
|
|
55eb7eb237 | ||
|
|
5a477eb391 | ||
|
|
86cda8771d | ||
|
|
c1ea0ddd35 | ||
|
|
f567711a6c | ||
|
|
509df4c676 | ||
|
|
b375ed07b4 | ||
|
|
69d62d3dd1 | ||
|
|
fe3fa3d505 | ||
|
|
27fcd43224 | ||
|
|
c3b2bb19d5 | ||
|
|
34287177b9 | ||
|
|
d238dd4059 | ||
|
|
7324d322fe |
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
@@ -14,9 +14,12 @@
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"auracast",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"biginfo",
|
||||
"bigs",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -36,6 +39,7 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"ediv",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
@@ -47,6 +51,7 @@
|
||||
"libc",
|
||||
"liblc",
|
||||
"libusb",
|
||||
"maxs",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
@@ -54,8 +59,10 @@
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"PDUS",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"prefs",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
|
||||
429
apps/auracast.py
429
apps/auracast.py
@@ -16,25 +16,35 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import wave
|
||||
import itertools
|
||||
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
|
||||
|
||||
import click
|
||||
import pyee
|
||||
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
|
||||
|
||||
from bumble.colors import color
|
||||
import bumble.company_ids
|
||||
import bumble.core
|
||||
from bumble import company_ids
|
||||
from bumble import core
|
||||
from bumble import gatt
|
||||
from bumble import hci
|
||||
from bumble.profiles import bap
|
||||
from bumble.profiles import le_audio
|
||||
from bumble.profiles import pbp
|
||||
from bumble.profiles import bass
|
||||
import bumble.device
|
||||
import bumble.gatt
|
||||
import bumble.hci
|
||||
import bumble.profiles.bap
|
||||
import bumble.profiles.bass
|
||||
import bumble.profiles.pbp
|
||||
import bumble.transport
|
||||
import bumble.utils
|
||||
|
||||
@@ -49,7 +59,7 @@ logger = logging.getLogger(__name__)
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5')
|
||||
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
|
||||
AURACAST_DEFAULT_ATT_MTU = 256
|
||||
|
||||
@@ -62,17 +72,12 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
class Broadcast(pyee.EventEmitter):
|
||||
name: str | None
|
||||
sync: bumble.device.PeriodicAdvertisingSync
|
||||
broadcast_id: int
|
||||
rssi: int = 0
|
||||
public_broadcast_announcement: Optional[
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement
|
||||
] = None
|
||||
broadcast_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement
|
||||
] = None
|
||||
basic_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BasicAudioAnnouncement
|
||||
] = None
|
||||
appearance: Optional[bumble.core.Appearance] = None
|
||||
public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None
|
||||
broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None
|
||||
basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None
|
||||
appearance: Optional[core.Appearance] = None
|
||||
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
|
||||
manufacturer_data: Optional[Tuple[str, bytes]] = None
|
||||
|
||||
@@ -86,42 +91,36 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
self.rssi = advertisement.rssi
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE:
|
||||
self.public_broadcast_announcement = (
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.broadcast_audio_announcement = (
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
self.appearance = advertisement.data.get( # type: ignore[assignment]
|
||||
bumble.core.AdvertisingData.APPEARANCE
|
||||
core.AdvertisingData.APPEARANCE
|
||||
)
|
||||
|
||||
if manufacturer_data := advertisement.data.get(
|
||||
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
):
|
||||
assert isinstance(manufacturer_data, tuple)
|
||||
company_id = cast(int, manufacturer_data[0])
|
||||
data = cast(bytes, manufacturer_data[1])
|
||||
self.manufacturer_data = (
|
||||
bumble.company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_id, f'0x{company_id:04X}'
|
||||
),
|
||||
data,
|
||||
@@ -232,15 +231,15 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
return
|
||||
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.basic_audio_announcement = (
|
||||
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
break
|
||||
|
||||
@@ -262,7 +261,7 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
self.device = device
|
||||
self.filter_duplicates = filter_duplicates
|
||||
self.sync_timeout = sync_timeout
|
||||
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
|
||||
self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]()
|
||||
device.on('advertisement', self.on_advertisement)
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -277,33 +276,44 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
if not (
|
||||
ads := advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
|
||||
)
|
||||
) or not (
|
||||
any(
|
||||
ad
|
||||
for ad in ads
|
||||
if isinstance(ad, tuple)
|
||||
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
broadcast_audio_announcement := next(
|
||||
(
|
||||
ad
|
||||
for ad in ads
|
||||
if isinstance(ad, tuple)
|
||||
and ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
),
|
||||
None,
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
broadcast_name = advertisement.data.get(
|
||||
bumble.core.AdvertisingData.BROADCAST_NAME
|
||||
)
|
||||
broadcast_name = advertisement.data.get(core.AdvertisingData.BROADCAST_NAME)
|
||||
assert isinstance(broadcast_name, str) or broadcast_name is None
|
||||
assert isinstance(broadcast_audio_announcement[1], bytes)
|
||||
|
||||
if broadcast := self.broadcasts.get(advertisement.address):
|
||||
broadcast.update(advertisement)
|
||||
return
|
||||
|
||||
bumble.utils.AsyncRunner.spawn(
|
||||
self.on_new_broadcast(broadcast_name, advertisement)
|
||||
self.on_new_broadcast(
|
||||
broadcast_name,
|
||||
advertisement,
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(
|
||||
broadcast_audio_announcement[1]
|
||||
).broadcast_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_new_broadcast(
|
||||
self, name: str | None, advertisement: bumble.device.Advertisement
|
||||
self,
|
||||
name: str | None,
|
||||
advertisement: bumble.device.Advertisement,
|
||||
broadcast_id: int,
|
||||
) -> None:
|
||||
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
|
||||
advertiser_address=advertisement.address,
|
||||
@@ -311,7 +321,7 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
sync_timeout=self.sync_timeout,
|
||||
filter_duplicates=self.filter_duplicates,
|
||||
)
|
||||
broadcast = self.Broadcast(name, periodic_advertising_sync)
|
||||
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
|
||||
broadcast.update(advertisement)
|
||||
self.broadcasts[advertisement.address] = broadcast
|
||||
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
||||
@@ -323,10 +333,11 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
self.emit('broadcast_loss', broadcast)
|
||||
|
||||
|
||||
class PrintingBroadcastScanner:
|
||||
class PrintingBroadcastScanner(pyee.EventEmitter):
|
||||
def __init__(
|
||||
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||
self.scanner.on('new_broadcast', self.on_new_broadcast)
|
||||
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
|
||||
@@ -461,24 +472,26 @@ async def run_assist(
|
||||
await peer.request_mtu(mtu)
|
||||
|
||||
# Get the BASS service
|
||||
bass = await peer.discover_service_and_create_proxy(
|
||||
bumble.profiles.bass.BroadcastAudioScanServiceProxy
|
||||
bass_client = await peer.discover_service_and_create_proxy(
|
||||
bass.BroadcastAudioScanServiceProxy
|
||||
)
|
||||
|
||||
# Check that the service was found
|
||||
if not bass:
|
||||
if not bass_client:
|
||||
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
|
||||
return
|
||||
|
||||
# Subscribe to and read the broadcast receive state characteristics
|
||||
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
|
||||
for i, broadcast_receive_state in enumerate(
|
||||
bass_client.broadcast_receive_states
|
||||
):
|
||||
try:
|
||||
await broadcast_receive_state.subscribe(
|
||||
lambda value, i=i: print(
|
||||
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
|
||||
)
|
||||
)
|
||||
except bumble.core.ProtocolError as error:
|
||||
except core.ProtocolError as error:
|
||||
print(
|
||||
color(
|
||||
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
|
||||
@@ -497,7 +510,7 @@ async def run_assist(
|
||||
|
||||
if command == 'add-source':
|
||||
# Find the requested broadcast
|
||||
await bass.remote_scan_started()
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
@@ -517,15 +530,15 @@ async def run_assist(
|
||||
|
||||
# Add the source
|
||||
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
|
||||
await bass.add_source(
|
||||
await bass_client.add_source(
|
||||
broadcast.sync.advertiser_address,
|
||||
broadcast.sync.sid,
|
||||
broadcast.broadcast_audio_announcement.broadcast_id,
|
||||
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bumble.profiles.bass.SubgroupInfo(
|
||||
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
@@ -535,7 +548,7 @@ async def run_assist(
|
||||
await broadcast.sync.transfer(peer.connection)
|
||||
|
||||
# Notify the sink that we're done scanning.
|
||||
await bass.remote_scan_stopped()
|
||||
await bass_client.remote_scan_stopped()
|
||||
|
||||
await peer.sustain()
|
||||
return
|
||||
@@ -546,7 +559,7 @@ async def run_assist(
|
||||
return
|
||||
|
||||
# Find the requested broadcast
|
||||
await bass.remote_scan_started()
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
@@ -569,13 +582,13 @@ async def run_assist(
|
||||
color('Modifying source:', 'blue'),
|
||||
source_id,
|
||||
)
|
||||
await bass.modify_source(
|
||||
await bass_client.modify_source(
|
||||
source_id,
|
||||
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bumble.profiles.bass.SubgroupInfo(
|
||||
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
@@ -590,7 +603,7 @@ async def run_assist(
|
||||
|
||||
# Remove the source
|
||||
print(color('Removing source:', 'blue'), source_id)
|
||||
await bass.remove_source(source_id)
|
||||
await bass_client.remove_source(source_id)
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
@@ -610,14 +623,228 @@ async def run_pair(transport: str, address: str) -> None:
|
||||
print("+++ Paired")
|
||||
|
||||
|
||||
async def run_receive(
|
||||
transport: str,
|
||||
broadcast_id: int,
|
||||
broadcast_code: str | None,
|
||||
sync_timeout: float,
|
||||
subgroup_index: int,
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
scanner = BroadcastScanner(device, False, sync_timeout)
|
||||
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
if scan_result.done():
|
||||
return
|
||||
if broadcast.broadcast_id == broadcast_id:
|
||||
scan_result.set_result(broadcast)
|
||||
|
||||
scanner.on('new_broadcast', on_new_broadcast)
|
||||
await scanner.start()
|
||||
print('Start scanning...')
|
||||
broadcast = await scan_result
|
||||
print('Advertisement found:')
|
||||
broadcast.print()
|
||||
basic_audio_announcement_scanned = asyncio.Event()
|
||||
|
||||
def on_change() -> None:
|
||||
if (
|
||||
broadcast.basic_audio_announcement
|
||||
and not basic_audio_announcement_scanned.is_set()
|
||||
):
|
||||
basic_audio_announcement_scanned.set()
|
||||
|
||||
broadcast.on('change', on_change)
|
||||
if not broadcast.basic_audio_announcement:
|
||||
print('Wait for Basic Audio Announcement...')
|
||||
await basic_audio_announcement_scanned.wait()
|
||||
print('Basic Audio Announcement found')
|
||||
broadcast.print()
|
||||
print('Stop scanning')
|
||||
await scanner.stop()
|
||||
print('Start sync to BIG')
|
||||
|
||||
assert broadcast.basic_audio_announcement
|
||||
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
|
||||
configuration = subgroup.codec_specific_configuration
|
||||
assert configuration
|
||||
assert (sampling_frequency := configuration.sampling_frequency)
|
||||
assert (frame_duration := configuration.frame_duration)
|
||||
|
||||
big_sync = await device.create_big_sync(
|
||||
broadcast.sync,
|
||||
bumble.device.BigSyncParameters(
|
||||
big_sync_timeout=0x4000,
|
||||
bis=[bis.index for bis in subgroup.bis],
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
num_bis = len(big_sync.bis_links)
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=num_bis,
|
||||
)
|
||||
sdus = [b''] * num_bis
|
||||
subprocess = await asyncio.create_subprocess_shell(
|
||||
f'stdbuf -i0 ffplay -ar {sampling_frequency.hz} -ac {num_bis} -f f32le pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
for i, bis_link in enumerate(big_sync.bis_links):
|
||||
print(f'Setup ISO for BIS {bis_link.handle}')
|
||||
|
||||
def sink(index: int, packet: hci.HCI_IsoDataPacket):
|
||||
nonlocal sdus
|
||||
sdus[index] = packet.iso_sdu_fragment
|
||||
if all(sdus) and subprocess.stdin:
|
||||
subprocess.stdin.write(decoder.decode(b''.join(sdus)).tobytes())
|
||||
sdus = [b''] * num_bis
|
||||
|
||||
bis_link.sink = functools.partial(sink, i)
|
||||
await bis_link.setup_data_path(
|
||||
direction=bis_link.Direction.CONTROLLER_TO_HOST
|
||||
)
|
||||
|
||||
terminated = asyncio.Event()
|
||||
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
|
||||
await terminated.wait()
|
||||
|
||||
|
||||
async def run_broadcast(
|
||||
transport: str, broadcast_id: int, broadcast_code: str | None, wav_file_path: str
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
with wave.open(wav_file_path, 'rb') as wav:
|
||||
print('Encoding wav file into lc3...')
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=10000,
|
||||
sample_rate_hz=48000,
|
||||
num_channels=2,
|
||||
input_sample_rate_hz=wav.getframerate(),
|
||||
)
|
||||
frames = list[bytes]()
|
||||
while pcm := wav.readframes(encoder.get_frame_samples()):
|
||||
frames.append(
|
||||
encoder.encode(pcm, num_bytes=200, bit_depth=wav.getsampwidth() * 8)
|
||||
)
|
||||
del encoder
|
||||
print('Encoding complete.')
|
||||
|
||||
basic_audio_announcement = bap.BasicAudioAnnouncement(
|
||||
presentation_delay=40000,
|
||||
subgroups=[
|
||||
bap.BasicAudioAnnouncement.Subgroup(
|
||||
codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3),
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
sampling_frequency=bap.SamplingFrequency.FREQ_48000,
|
||||
frame_duration=bap.FrameDuration.DURATION_10000_US,
|
||||
octets_per_codec_frame=100,
|
||||
),
|
||||
metadata=le_audio.Metadata(
|
||||
[
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng'
|
||||
),
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco'
|
||||
),
|
||||
]
|
||||
),
|
||||
bis=[
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=1,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_LEFT
|
||||
),
|
||||
),
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=2,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
|
||||
print('Start Advertising')
|
||||
advertising_set = await device.create_advertising_set(
|
||||
advertising_parameters=bumble.device.AdvertisingParameters(
|
||||
advertising_event_properties=bumble.device.AdvertisingEventProperties(
|
||||
is_connectable=False
|
||||
),
|
||||
primary_advertising_interval_min=100,
|
||||
primary_advertising_interval_max=200,
|
||||
),
|
||||
advertising_data=(
|
||||
broadcast_audio_announcement.get_advertising_data()
|
||||
+ bytes(
|
||||
core.AdvertisingData(
|
||||
[(core.AdvertisingData.BROADCAST_NAME, b'Bumble Auracast')]
|
||||
)
|
||||
)
|
||||
),
|
||||
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
|
||||
periodic_advertising_interval_min=80,
|
||||
periodic_advertising_interval_max=160,
|
||||
),
|
||||
periodic_advertising_data=basic_audio_announcement.get_advertising_data(),
|
||||
auto_restart=True,
|
||||
auto_start=True,
|
||||
)
|
||||
print('Start Periodic Advertising')
|
||||
await advertising_set.start_periodic()
|
||||
print('Setup BIG')
|
||||
big = await device.create_big(
|
||||
advertising_set,
|
||||
parameters=bumble.device.BigParameters(
|
||||
num_bis=2,
|
||||
sdu_interval=10000,
|
||||
max_sdu=100,
|
||||
max_transport_latency=65,
|
||||
rtn=4,
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
print('Setup ISO Data Path')
|
||||
for bis_link in big.bis_links:
|
||||
await bis_link.setup_data_path(
|
||||
direction=bis_link.Direction.HOST_TO_CONTROLLER
|
||||
)
|
||||
|
||||
for frame in itertools.cycle(frames):
|
||||
mid = len(frame) // 2
|
||||
big.bis_links[0].write(frame[:mid])
|
||||
big.bis_links[1].write(frame[mid:])
|
||||
await asyncio.sleep(0.009)
|
||||
|
||||
|
||||
def run_async(async_command: Coroutine) -> None:
|
||||
try:
|
||||
asyncio.run(async_command)
|
||||
except bumble.core.ProtocolError as error:
|
||||
except core.ProtocolError as error:
|
||||
if error.error_namespace == 'att' and error.error_code in list(
|
||||
bumble.profiles.bass.ApplicationError
|
||||
bass.ApplicationError
|
||||
):
|
||||
message = bumble.profiles.bass.ApplicationError(error.error_code).name
|
||||
message = bass.ApplicationError(error.error_code).name
|
||||
else:
|
||||
message = str(error)
|
||||
|
||||
@@ -631,9 +858,7 @@ def run_async(async_command: Coroutine) -> None:
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
def auracast(
|
||||
ctx,
|
||||
):
|
||||
def auracast(ctx):
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
|
||||
@@ -691,6 +916,66 @@ def pair(ctx, transport, address):
|
||||
run_async(run_pair(transport, address))
|
||||
|
||||
|
||||
@auracast.command('receive')
|
||||
@click.argument('transport')
|
||||
@click.argument('broadcast_id', type=int)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
type=str,
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.option(
|
||||
'--sync-timeout',
|
||||
metavar='SYNC_TIMEOUT',
|
||||
type=float,
|
||||
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
|
||||
help='Sync timeout (in seconds)',
|
||||
)
|
||||
@click.option(
|
||||
'--subgroup',
|
||||
metavar='SUBGROUP',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Index of Subgroup',
|
||||
)
|
||||
@click.pass_context
|
||||
def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout, subgroup):
|
||||
"""Receive a broadcast source"""
|
||||
run_async(
|
||||
run_receive(transport, broadcast_id, broadcast_code, sync_timeout, subgroup)
|
||||
)
|
||||
|
||||
|
||||
@auracast.command('broadcast')
|
||||
@click.argument('transport')
|
||||
@click.argument('wav_file_path', type=str)
|
||||
@click.option(
|
||||
'--broadcast-id',
|
||||
metavar='BROADCAST_ID',
|
||||
type=int,
|
||||
default=123456,
|
||||
help='Broadcast ID',
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
type=str,
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.pass_context
|
||||
def broadcast(ctx, transport, broadcast_id, broadcast_code, wav_file_path):
|
||||
"""Start a broadcast as a source."""
|
||||
run_async(
|
||||
run_broadcast(
|
||||
transport=transport,
|
||||
broadcast_id=broadcast_id,
|
||||
broadcast_code=broadcast_code,
|
||||
wav_file_path=wav_file_path,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
auracast()
|
||||
|
||||
@@ -39,7 +39,7 @@ import aiohttp.web
|
||||
import bumble
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
|
||||
from bumble.transport import open_transport
|
||||
from bumble.profiles import ascs, bap, pacs
|
||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||
@@ -110,7 +110,7 @@ async def lc3_source_task(
|
||||
sdu_length: int,
|
||||
frame_duration_us: int,
|
||||
device: Device,
|
||||
cis_handle: int,
|
||||
cis_link: CisLink,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
|
||||
@@ -120,7 +120,6 @@ async def lc3_source_task(
|
||||
)
|
||||
with wave.open(filename, 'rb') as wav:
|
||||
bits_per_sample = wav.getsampwidth() * 8
|
||||
packet_sequence_number = 0
|
||||
|
||||
encoder: lc3.Encoder | None = None
|
||||
|
||||
@@ -150,18 +149,8 @@ async def lc3_source_task(
|
||||
num_bytes=sdu_length,
|
||||
bit_depth=bits_per_sample,
|
||||
)
|
||||
cis_link.write(sdu)
|
||||
|
||||
iso_packet = HCI_IsoDataPacket(
|
||||
connection_handle=cis_handle,
|
||||
data_total_length=sdu_length + 4,
|
||||
packet_sequence_number=packet_sequence_number,
|
||||
pb_flag=0b10,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=sdu_length,
|
||||
iso_sdu_fragment=sdu,
|
||||
)
|
||||
device.host.send_hci_packet(iso_packet)
|
||||
packet_sequence_number += 1
|
||||
sleep_time = next_round - datetime.datetime.now()
|
||||
await asyncio.sleep(sleep_time.total_seconds() * 0.9)
|
||||
|
||||
@@ -309,6 +298,7 @@ class Speaker:
|
||||
advertising_interval_min=25,
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
@@ -393,7 +383,7 @@ class Speaker:
|
||||
),
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
device=self.device,
|
||||
cis_handle=ase.cis_link.handle,
|
||||
cis_link=ase.cis_link,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1501,7 +1501,10 @@ class AdvertisingData:
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
|
||||
ad_type_str = 'Complete Local Name'
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
try:
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
except UnicodeDecodeError:
|
||||
ad_data_str = ad_data.hex()
|
||||
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
|
||||
ad_type_str = 'TX Power Level'
|
||||
ad_data_str = str(ad_data[0])
|
||||
|
||||
499
bumble/device.py
499
bumble/device.py
@@ -17,7 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Sequence
|
||||
from contextlib import (
|
||||
asynccontextmanager,
|
||||
AsyncExitStack,
|
||||
@@ -37,9 +37,7 @@ from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -119,6 +117,8 @@ DEVICE_MIN_LE_RSSI = -127
|
||||
DEVICE_MAX_LE_RSSI = 20
|
||||
DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE = 0x00
|
||||
DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE = 0xEF
|
||||
DEVICE_MIN_BIG_HANDLE = 0x00
|
||||
DEVICE_MAX_BIG_HANDLE = 0xEF
|
||||
|
||||
DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00'
|
||||
DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms
|
||||
@@ -992,6 +992,130 @@ class PeriodicAdvertisingSync(EventEmitter):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class BigParameters:
|
||||
num_bis: int
|
||||
sdu_interval: int
|
||||
max_sdu: int
|
||||
max_transport_latency: int
|
||||
rtn: int
|
||||
phy: hci.PhyBit = hci.PhyBit.LE_2M
|
||||
packing: int = 0
|
||||
framing: int = 0
|
||||
broadcast_code: bytes | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class Big(EventEmitter):
|
||||
class State(IntEnum):
|
||||
PENDING = 0
|
||||
ACTIVE = 1
|
||||
TERMINATED = 2
|
||||
|
||||
class Event(str, Enum):
|
||||
ESTABLISHMENT = 'establishment'
|
||||
ESTABLISHMENT_FAILURE = 'establishment_failure'
|
||||
TERMINATION = 'termination'
|
||||
|
||||
big_handle: int
|
||||
advertising_set: AdvertisingSet
|
||||
parameters: BigParameters
|
||||
state: State = State.PENDING
|
||||
|
||||
# Attributes provided by BIG Create Complete event
|
||||
big_sync_delay: int = 0
|
||||
transport_latency_big: int = 0
|
||||
phy: int = 0
|
||||
nse: int = 0
|
||||
bn: int = 0
|
||||
pto: int = 0
|
||||
irc: int = 0
|
||||
max_pdu: int = 0
|
||||
iso_interval: int = 0
|
||||
bis_links: Sequence[BisLink] = ()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.device = self.advertising_set.device
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
) -> None:
|
||||
if self.state != Big.State.ACTIVE:
|
||||
logger.error('BIG %d is not active.', self.big_handle)
|
||||
return
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
terminated = asyncio.Event()
|
||||
watcher.once(self, Big.Event.TERMINATION, lambda _: terminated.set())
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_Terminate_BIG_Command(
|
||||
big_handle=self.big_handle, reason=reason
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
await terminated.wait()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class BigSyncParameters:
|
||||
big_sync_timeout: int
|
||||
bis: Sequence[int]
|
||||
mse: int = 0
|
||||
broadcast_code: bytes | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class BigSync(EventEmitter):
|
||||
class State(IntEnum):
|
||||
PENDING = 0
|
||||
ACTIVE = 1
|
||||
TERMINATED = 2
|
||||
|
||||
class Event(str, Enum):
|
||||
ESTABLISHMENT = 'establishment'
|
||||
ESTABLISHMENT_FAILURE = 'establishment_failure'
|
||||
TERMINATION = 'termination'
|
||||
|
||||
big_handle: int
|
||||
pa_sync: PeriodicAdvertisingSync
|
||||
parameters: BigSyncParameters
|
||||
state: State = State.PENDING
|
||||
|
||||
# Attributes provided by BIG Create Sync Complete event
|
||||
transport_latency_big: int = 0
|
||||
nse: int = 0
|
||||
bn: int = 0
|
||||
pto: int = 0
|
||||
irc: int = 0
|
||||
max_pdu: int = 0
|
||||
iso_interval: int = 0
|
||||
bis_links: Sequence[BisLink] = ()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.device = self.pa_sync.device
|
||||
|
||||
async def terminate(self) -> None:
|
||||
if self.state != BigSync.State.ACTIVE:
|
||||
logger.error('BIG Sync %d is not active.', self.big_handle)
|
||||
return
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
terminated = asyncio.Event()
|
||||
watcher.once(self, BigSync.Event.TERMINATION, lambda _: terminated.set())
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_BIG_Terminate_Sync_Command(big_handle=self.big_handle),
|
||||
check_result=True,
|
||||
)
|
||||
await terminated.wait()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LePhyOptions:
|
||||
# Coded PHY preference
|
||||
@@ -1019,7 +1143,7 @@ class Peer:
|
||||
connection.gatt_client = self.gatt_client
|
||||
|
||||
@property
|
||||
def services(self) -> List[gatt_client.ServiceProxy]:
|
||||
def services(self) -> list[gatt_client.ServiceProxy]:
|
||||
return self.gatt_client.services
|
||||
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
@@ -1029,24 +1153,24 @@ class Peer:
|
||||
|
||||
async def discover_service(
|
||||
self, uuid: Union[core.UUID, str]
|
||||
) -> List[gatt_client.ServiceProxy]:
|
||||
) -> list[gatt_client.ServiceProxy]:
|
||||
return await self.gatt_client.discover_service(uuid)
|
||||
|
||||
async def discover_services(
|
||||
self, uuids: Iterable[core.UUID] = ()
|
||||
) -> List[gatt_client.ServiceProxy]:
|
||||
) -> list[gatt_client.ServiceProxy]:
|
||||
return await self.gatt_client.discover_services(uuids)
|
||||
|
||||
async def discover_included_services(
|
||||
self, service: gatt_client.ServiceProxy
|
||||
) -> List[gatt_client.ServiceProxy]:
|
||||
) -> list[gatt_client.ServiceProxy]:
|
||||
return await self.gatt_client.discover_included_services(service)
|
||||
|
||||
async def discover_characteristics(
|
||||
self,
|
||||
uuids: Iterable[Union[core.UUID, str]] = (),
|
||||
service: Optional[gatt_client.ServiceProxy] = None,
|
||||
) -> List[gatt_client.CharacteristicProxy]:
|
||||
) -> list[gatt_client.CharacteristicProxy]:
|
||||
return await self.gatt_client.discover_characteristics(
|
||||
uuids=uuids, service=service
|
||||
)
|
||||
@@ -1061,7 +1185,7 @@ class Peer:
|
||||
characteristic, start_handle, end_handle
|
||||
)
|
||||
|
||||
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
|
||||
async def discover_attributes(self) -> list[gatt_client.AttributeProxy]:
|
||||
return await self.gatt_client.discover_attributes()
|
||||
|
||||
async def discover_all(self):
|
||||
@@ -1105,17 +1229,17 @@ class Peer:
|
||||
|
||||
async def read_characteristics_by_uuid(
|
||||
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
|
||||
) -> List[bytes]:
|
||||
) -> list[bytes]:
|
||||
return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
|
||||
|
||||
def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
|
||||
def get_services_by_uuid(self, uuid: core.UUID) -> list[gatt_client.ServiceProxy]:
|
||||
return self.gatt_client.get_services_by_uuid(uuid)
|
||||
|
||||
def get_characteristics_by_uuid(
|
||||
self,
|
||||
uuid: core.UUID,
|
||||
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
|
||||
) -> List[gatt_client.CharacteristicProxy]:
|
||||
) -> list[gatt_client.CharacteristicProxy]:
|
||||
if isinstance(service, core.UUID):
|
||||
return list(
|
||||
itertools.chain(
|
||||
@@ -1201,9 +1325,93 @@ class ScoLink(CompositeEventEmitter):
|
||||
await self.device.disconnect(self, reason)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class _IsoLink:
|
||||
handle: int
|
||||
device: Device
|
||||
packet_sequence_number: int
|
||||
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
|
||||
|
||||
class Direction(IntEnum):
|
||||
HOST_TO_CONTROLLER = (
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||
)
|
||||
CONTROLLER_TO_HOST = (
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
)
|
||||
|
||||
async def setup_data_path(
|
||||
self,
|
||||
direction: _IsoLink.Direction,
|
||||
data_path_id: int = 0,
|
||||
codec_id: hci.CodingFormat | None = None,
|
||||
controller_delay: int = 0,
|
||||
codec_configuration: bytes = b'',
|
||||
) -> None:
|
||||
"""Create a data path between controller and given entry.
|
||||
|
||||
Args:
|
||||
direction: Direction of data path.
|
||||
data_path_id: ID of data path. Default is 0 (HCI).
|
||||
codec_id: Codec ID. Default is Transparent.
|
||||
controller_delay: Controller delay in microseconds. Default is 0.
|
||||
codec_configuration: Codec-specific configuration.
|
||||
|
||||
Raises:
|
||||
HCI_Error: When command complete status is not HCI_SUCCESS.
|
||||
"""
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=direction,
|
||||
data_path_id=data_path_id,
|
||||
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=controller_delay,
|
||||
codec_configuration=codec_configuration,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
async def remove_data_path(self, direction: _IsoLink.Direction) -> int:
|
||||
"""Remove a data path with controller on given direction.
|
||||
|
||||
Args:
|
||||
direction: Direction of data path.
|
||||
|
||||
Returns:
|
||||
Command status.
|
||||
"""
|
||||
response = await self.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=direction,
|
||||
),
|
||||
check_result=False,
|
||||
)
|
||||
return response.return_parameters.status
|
||||
|
||||
def write(self, sdu: bytes) -> None:
|
||||
"""Write an ISO SDU.
|
||||
|
||||
This will automatically increase the packet sequence number.
|
||||
"""
|
||||
self.device.host.send_hci_packet(
|
||||
hci.HCI_IsoDataPacket(
|
||||
connection_handle=self.handle,
|
||||
data_total_length=len(sdu) + 4,
|
||||
packet_sequence_number=self.packet_sequence_number,
|
||||
pb_flag=0b10,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=len(sdu),
|
||||
iso_sdu_fragment=sdu,
|
||||
)
|
||||
)
|
||||
self.packet_sequence_number = (self.packet_sequence_number + 1) % 0x10000
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class CisLink(CompositeEventEmitter):
|
||||
class CisLink(CompositeEventEmitter, _IsoLink):
|
||||
class State(IntEnum):
|
||||
PENDING = 0
|
||||
ESTABLISHED = 1
|
||||
@@ -1214,10 +1422,11 @@ class CisLink(CompositeEventEmitter):
|
||||
cis_id: int # CIS ID assigned by Central device
|
||||
cig_id: int # CIG ID assigned by Central device
|
||||
state: State = State.PENDING
|
||||
sink: Optional[Callable[[hci.HCI_IsoDataPacket], Any]] = None
|
||||
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.packet_sequence_number = 0
|
||||
|
||||
async def disconnect(
|
||||
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
|
||||
@@ -1225,6 +1434,18 @@ class CisLink(CompositeEventEmitter):
|
||||
await self.device.disconnect(self, reason)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class BisLink(_IsoLink):
|
||||
handle: int
|
||||
big: Big | BigSync
|
||||
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.device = self.big.device
|
||||
self.packet_sequence_number = 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Connection(CompositeEventEmitter):
|
||||
device: Device
|
||||
@@ -1537,7 +1758,7 @@ class DeviceConfiguration:
|
||||
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.gatt_services: List[Dict[str, Any]] = []
|
||||
self.gatt_services: list[Dict[str, Any]] = []
|
||||
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
config = copy.deepcopy(config)
|
||||
@@ -1684,7 +1905,7 @@ def host_event_handler(function):
|
||||
# List of host event handlers for the Device class.
|
||||
# (we define this list outside the class, because referencing a class in method
|
||||
# decorators is not straightforward)
|
||||
device_host_event_handlers: List[str] = []
|
||||
device_host_event_handlers: list[str] = []
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1705,15 +1926,18 @@ class Device(CompositeEventEmitter):
|
||||
pending_connections: Dict[hci.Address, Connection]
|
||||
classic_pending_accepts: Dict[
|
||||
hci.Address,
|
||||
List[asyncio.Future[Union[Connection, Tuple[hci.Address, int, int]]]],
|
||||
list[asyncio.Future[Union[Connection, tuple[hci.Address, int, int]]]],
|
||||
]
|
||||
advertisement_accumulators: Dict[hci.Address, AdvertisementDataAccumulator]
|
||||
periodic_advertising_syncs: List[PeriodicAdvertisingSync]
|
||||
periodic_advertising_syncs: list[PeriodicAdvertisingSync]
|
||||
config: DeviceConfiguration
|
||||
legacy_advertiser: Optional[LegacyAdvertiser]
|
||||
sco_links: Dict[int, ScoLink]
|
||||
cis_links: Dict[int, CisLink]
|
||||
_pending_cis: Dict[int, Tuple[int, int]]
|
||||
bigs = dict[int, Big]()
|
||||
bis_links = dict[int, BisLink]()
|
||||
big_syncs = dict[int, BigSync]()
|
||||
_pending_cis: Dict[int, tuple[int, int]]
|
||||
|
||||
@composite_listener
|
||||
class Listener:
|
||||
@@ -2009,6 +2233,17 @@ class Device(CompositeEventEmitter):
|
||||
None,
|
||||
)
|
||||
|
||||
def next_big_handle(self) -> int | None:
|
||||
return next(
|
||||
(
|
||||
handle
|
||||
for handle in range(DEVICE_MIN_BIG_HANDLE, DEVICE_MAX_BIG_HANDLE + 1)
|
||||
if handle
|
||||
not in itertools.chain(self.bigs.keys(), self.big_syncs.keys())
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@deprecated("Please use create_l2cap_server()")
|
||||
def register_l2cap_server(self, psm, server) -> int:
|
||||
return self.l2cap_channel_manager.register_server(psm, server)
|
||||
@@ -2627,7 +2862,7 @@ class Device(CompositeEventEmitter):
|
||||
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
|
||||
own_address_type: int = hci.OwnAddressType.RANDOM,
|
||||
filter_duplicates: bool = False,
|
||||
scanning_phys: List[int] = [hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY],
|
||||
scanning_phys: Sequence[int] = (hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY),
|
||||
) -> None:
|
||||
# Check that the arguments are legal
|
||||
if scan_interval < scan_window:
|
||||
@@ -2674,7 +2909,7 @@ class Device(CompositeEventEmitter):
|
||||
scanning_filter_policy=scanning_filter_policy,
|
||||
scanning_phys=scanning_phys_bits,
|
||||
scan_types=[scan_type] * scanning_phy_count,
|
||||
scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
scan_intervals=[int(scan_interval / 0.625)] * scanning_phy_count,
|
||||
scan_windows=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
),
|
||||
check_result=True,
|
||||
@@ -3958,13 +4193,13 @@ class Device(CompositeEventEmitter):
|
||||
async def setup_cig(
|
||||
self,
|
||||
cig_id: int,
|
||||
cis_id: List[int],
|
||||
sdu_interval: Tuple[int, int],
|
||||
cis_id: Sequence[int],
|
||||
sdu_interval: tuple[int, int],
|
||||
framing: int,
|
||||
max_sdu: Tuple[int, int],
|
||||
max_sdu: tuple[int, int],
|
||||
retransmission_number: int,
|
||||
max_transport_latency: Tuple[int, int],
|
||||
) -> List[int]:
|
||||
max_transport_latency: tuple[int, int],
|
||||
) -> list[int]:
|
||||
"""Sends hci.HCI_LE_Set_CIG_Parameters_Command.
|
||||
|
||||
Args:
|
||||
@@ -4013,7 +4248,9 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink]:
|
||||
async def create_cis(
|
||||
self, cis_acl_pairs: Sequence[tuple[int, int]]
|
||||
) -> list[CisLink]:
|
||||
for cis_handle, acl_handle in cis_acl_pairs:
|
||||
acl_connection = self.lookup_connection(acl_handle)
|
||||
assert acl_connection
|
||||
@@ -4112,6 +4349,106 @@ class Device(CompositeEventEmitter):
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
async def create_big(
|
||||
self, advertising_set: AdvertisingSet, parameters: BigParameters
|
||||
) -> Big:
|
||||
if (big_handle := self.next_big_handle()) is None:
|
||||
raise core.OutOfResourcesError("All valid BIG handles already in use")
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
big = Big(
|
||||
big_handle=big_handle,
|
||||
parameters=parameters,
|
||||
advertising_set=advertising_set,
|
||||
)
|
||||
self.bigs[big_handle] = big
|
||||
established = asyncio.get_running_loop().create_future()
|
||||
watcher.once(
|
||||
big, big.Event.ESTABLISHMENT, lambda: established.set_result(None)
|
||||
)
|
||||
watcher.once(
|
||||
big,
|
||||
big.Event.ESTABLISHMENT_FAILURE,
|
||||
lambda status: established.set_exception(hci.HCI_Error(status)),
|
||||
)
|
||||
|
||||
try:
|
||||
await self.send_command(
|
||||
hci.HCI_LE_Create_BIG_Command(
|
||||
big_handle=big_handle,
|
||||
advertising_handle=advertising_set.advertising_handle,
|
||||
num_bis=parameters.num_bis,
|
||||
sdu_interval=parameters.sdu_interval,
|
||||
max_sdu=parameters.max_sdu,
|
||||
max_transport_latency=parameters.max_transport_latency,
|
||||
rtn=parameters.rtn,
|
||||
phy=parameters.phy,
|
||||
packing=parameters.packing,
|
||||
framing=parameters.framing,
|
||||
encryption=1 if parameters.broadcast_code else 0,
|
||||
broadcast_code=parameters.broadcast_code or bytes(16),
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
await established
|
||||
except hci.HCI_Error:
|
||||
del self.bigs[big_handle]
|
||||
raise
|
||||
|
||||
return big
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
async def create_big_sync(
|
||||
self, pa_sync: PeriodicAdvertisingSync, parameters: BigSyncParameters
|
||||
) -> BigSync:
|
||||
if (big_handle := self.next_big_handle()) is None:
|
||||
raise core.OutOfResourcesError("All valid BIG handles already in use")
|
||||
|
||||
if (pa_sync_handle := pa_sync.sync_handle) is None:
|
||||
raise core.InvalidStateError("PA Sync is not established")
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
big_sync = BigSync(
|
||||
big_handle=big_handle,
|
||||
parameters=parameters,
|
||||
pa_sync=pa_sync,
|
||||
)
|
||||
self.big_syncs[big_handle] = big_sync
|
||||
established = asyncio.get_running_loop().create_future()
|
||||
watcher.once(
|
||||
big_sync,
|
||||
big_sync.Event.ESTABLISHMENT,
|
||||
lambda: established.set_result(None),
|
||||
)
|
||||
watcher.once(
|
||||
big_sync,
|
||||
big_sync.Event.ESTABLISHMENT_FAILURE,
|
||||
lambda status: established.set_exception(hci.HCI_Error(status)),
|
||||
)
|
||||
|
||||
try:
|
||||
await self.send_command(
|
||||
hci.HCI_LE_BIG_Create_Sync_Command(
|
||||
big_handle=big_handle,
|
||||
sync_handle=pa_sync_handle,
|
||||
encryption=1 if parameters.broadcast_code else 0,
|
||||
broadcast_code=parameters.broadcast_code or bytes(16),
|
||||
mse=parameters.mse,
|
||||
big_sync_timeout=parameters.big_sync_timeout,
|
||||
bis=parameters.bis,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
await established
|
||||
except hci.HCI_Error:
|
||||
del self.big_syncs[big_handle]
|
||||
raise
|
||||
|
||||
return big_sync
|
||||
|
||||
async def get_remote_le_features(self, connection: Connection) -> hci.LeFeatureMask:
|
||||
"""[LE Only] Reads remote LE supported features.
|
||||
|
||||
@@ -4233,6 +4570,112 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
self.connecting_extended_advertising_sets[connection_handle] = advertising_set
|
||||
|
||||
@host_event_handler
|
||||
def on_big_establishment(
|
||||
self,
|
||||
status: int,
|
||||
big_handle: int,
|
||||
bis_handles: list[int],
|
||||
big_sync_delay: int,
|
||||
transport_latency_big: int,
|
||||
phy: int,
|
||||
nse: int,
|
||||
bn: int,
|
||||
pto: int,
|
||||
irc: int,
|
||||
max_pdu: int,
|
||||
iso_interval: int,
|
||||
) -> None:
|
||||
if not (big := self.bigs.get(big_handle)):
|
||||
logger.warning('BIG %d not found', big_handle)
|
||||
return
|
||||
|
||||
if status != hci.HCI_SUCCESS:
|
||||
del self.bigs[big_handle]
|
||||
logger.debug('Unable to create BIG %d', big_handle)
|
||||
big.state = Big.State.TERMINATED
|
||||
big.emit(Big.Event.ESTABLISHMENT_FAILURE, status)
|
||||
return
|
||||
|
||||
big.bis_links = [BisLink(handle=handle, big=big) for handle in bis_handles]
|
||||
big.big_sync_delay = big_sync_delay
|
||||
big.transport_latency_big = transport_latency_big
|
||||
big.phy = phy
|
||||
big.nse = nse
|
||||
big.bn = bn
|
||||
big.pto = pto
|
||||
big.irc = irc
|
||||
big.max_pdu = max_pdu
|
||||
big.iso_interval = iso_interval
|
||||
big.state = Big.State.ACTIVE
|
||||
|
||||
for bis_link in big.bis_links:
|
||||
self.bis_links[bis_link.handle] = bis_link
|
||||
big.emit(Big.Event.ESTABLISHMENT)
|
||||
|
||||
@host_event_handler
|
||||
def on_big_termination(self, reason: int, big_handle: int) -> None:
|
||||
if not (big := self.bigs.pop(big_handle, None)):
|
||||
logger.warning('BIG %d not found', big_handle)
|
||||
return
|
||||
|
||||
big.state = Big.State.TERMINATED
|
||||
for bis_link in big.bis_links:
|
||||
self.bis_links.pop(bis_link.handle, None)
|
||||
big.emit(Big.Event.TERMINATION, reason)
|
||||
|
||||
@host_event_handler
|
||||
def on_big_sync_establishment(
|
||||
self,
|
||||
status: int,
|
||||
big_handle: int,
|
||||
transport_latency_big: int,
|
||||
nse: int,
|
||||
bn: int,
|
||||
pto: int,
|
||||
irc: int,
|
||||
max_pdu: int,
|
||||
iso_interval: int,
|
||||
bis_handles: list[int],
|
||||
) -> None:
|
||||
if not (big_sync := self.big_syncs.get(big_handle)):
|
||||
logger.warning('BIG Sync %d not found', big_handle)
|
||||
return
|
||||
|
||||
if status != hci.HCI_SUCCESS:
|
||||
del self.big_syncs[big_handle]
|
||||
logger.debug('Unable to create BIG Sync %d', big_handle)
|
||||
big_sync.state = BigSync.State.TERMINATED
|
||||
big_sync.emit(BigSync.Event.ESTABLISHMENT_FAILURE, status)
|
||||
return
|
||||
|
||||
big_sync.transport_latency_big = transport_latency_big
|
||||
big_sync.nse = nse
|
||||
big_sync.bn = bn
|
||||
big_sync.pto = pto
|
||||
big_sync.irc = irc
|
||||
big_sync.max_pdu = max_pdu
|
||||
big_sync.iso_interval = iso_interval
|
||||
big_sync.bis_links = [
|
||||
BisLink(handle=handle, big=big_sync) for handle in bis_handles
|
||||
]
|
||||
big_sync.state = BigSync.State.ACTIVE
|
||||
|
||||
for bis_link in big_sync.bis_links:
|
||||
self.bis_links[bis_link.handle] = bis_link
|
||||
big_sync.emit(BigSync.Event.ESTABLISHMENT)
|
||||
|
||||
@host_event_handler
|
||||
def on_big_sync_lost(self, big_handle: int, reason: int) -> None:
|
||||
if not (big_sync := self.big_syncs.pop(big_handle, None)):
|
||||
logger.warning('BIG %d not found', big_handle)
|
||||
return
|
||||
|
||||
for bis_link in big_sync.bis_links:
|
||||
self.bis_links.pop(bis_link.handle, None)
|
||||
big_sync.state = BigSync.State.TERMINATED
|
||||
big_sync.emit(BigSync.Event.TERMINATION, reason)
|
||||
|
||||
def _complete_le_extended_advertising_connection(
|
||||
self, connection: Connection, advertising_set: AdvertisingSet
|
||||
) -> None:
|
||||
@@ -4879,6 +5322,8 @@ class Device(CompositeEventEmitter):
|
||||
def on_iso_packet(self, handle: int, packet: hci.HCI_IsoDataPacket) -> None:
|
||||
if (cis_link := self.cis_links.get(handle)) and cis_link.sink:
|
||||
cis_link.sink(packet)
|
||||
elif (bis_link := self.bis_links.get(handle)) and bis_link.sink:
|
||||
bis_link.sink(packet)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -4936,7 +4936,7 @@ class HCI_LE_Create_BIG_Command(HCI_Command):
|
||||
packing: int
|
||||
framing: int
|
||||
encryption: int
|
||||
broadcast_code: int
|
||||
broadcast_code: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -5070,7 +5070,7 @@ class HCI_Event(HCI_Packet):
|
||||
hci_packet_type = HCI_EVENT_PACKET
|
||||
event_names: Dict[int, str] = {}
|
||||
event_classes: Dict[int, Type[HCI_Event]] = {}
|
||||
vendor_factory: Optional[Callable[[bytes], Optional[HCI_Event]]] = None
|
||||
vendor_factories: list[Callable[[bytes], Optional[HCI_Event]]] = []
|
||||
|
||||
@staticmethod
|
||||
def event(fields=()):
|
||||
@@ -5128,6 +5128,19 @@ class HCI_Event(HCI_Packet):
|
||||
|
||||
return event_class
|
||||
|
||||
@classmethod
|
||||
def add_vendor_factory(
|
||||
cls, factory: Callable[[bytes], Optional[HCI_Event]]
|
||||
) -> None:
|
||||
cls.vendor_factories.append(factory)
|
||||
|
||||
@classmethod
|
||||
def remove_vendor_factory(
|
||||
cls, factory: Callable[[bytes], Optional[HCI_Event]]
|
||||
) -> None:
|
||||
if factory in cls.vendor_factories:
|
||||
cls.vendor_factories.remove(factory)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, packet: bytes) -> HCI_Event:
|
||||
event_code = packet[1]
|
||||
@@ -5148,13 +5161,13 @@ class HCI_Event(HCI_Packet):
|
||||
elif event_code == HCI_VENDOR_EVENT:
|
||||
# Invoke all the registered factories to see if any of them can handle
|
||||
# the event
|
||||
if cls.vendor_factory:
|
||||
if event := cls.vendor_factory(parameters):
|
||||
for vendor_factory in cls.vendor_factories:
|
||||
if event := vendor_factory(parameters):
|
||||
return event
|
||||
|
||||
# No factory, or the factory could not create an instance,
|
||||
# return a generic vendor event
|
||||
return HCI_Event(event_code, parameters)
|
||||
return HCI_Vendor_Event(data=parameters)
|
||||
else:
|
||||
subclass = HCI_Event.event_classes.get(event_code)
|
||||
if subclass is None:
|
||||
@@ -5825,7 +5838,7 @@ class HCI_LE_Periodic_Advertising_Sync_Lost_Event(HCI_LE_Meta_Event):
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
('status', 1),
|
||||
('status', STATUS_SPEC),
|
||||
('advertising_handle', 1),
|
||||
('connection_handle', 2),
|
||||
('num_completed_extended_advertising_events', 1),
|
||||
@@ -5908,6 +5921,70 @@ class HCI_LE_CIS_Request_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
('status', STATUS_SPEC),
|
||||
('big_handle', 1),
|
||||
('big_sync_delay', 3),
|
||||
('transport_latency_big', 3),
|
||||
('phy', 1),
|
||||
('nse', 1),
|
||||
('bn', 1),
|
||||
('pto', 1),
|
||||
('irc', 1),
|
||||
('max_pdu', 2),
|
||||
('iso_interval', 2),
|
||||
[('connection_handle', 2)],
|
||||
]
|
||||
)
|
||||
class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.27 LE Create BIG Complete Event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event([('big_handle', 1), ('reason', 1)])
|
||||
class HCI_LE_Terminate_BIG_Complete_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.28 LE Terminate BIG Complete Event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
('status', STATUS_SPEC),
|
||||
('big_handle', 1),
|
||||
('transport_latency_big', 3),
|
||||
('nse', 1),
|
||||
('bn', 1),
|
||||
('pto', 1),
|
||||
('irc', 1),
|
||||
('max_pdu', 2),
|
||||
('iso_interval', 2),
|
||||
[('connection_handle', 2)],
|
||||
]
|
||||
)
|
||||
class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.29 LE BIG Sync Established event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@HCI_LE_Meta_Event.event([('big_handle', 1), ('reason', 1)])
|
||||
class HCI_LE_BIG_Sync_Lost_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.30 LE BIG Sync Lost event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
|
||||
@@ -21,6 +21,7 @@ import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
import itertools
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -149,6 +150,7 @@ class Host(AbortableEventEmitter):
|
||||
connections: Dict[int, Connection]
|
||||
cis_links: Dict[int, CisLink]
|
||||
sco_links: Dict[int, ScoLink]
|
||||
bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles
|
||||
acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
le_acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
hci_sink: Optional[TransportSink] = None
|
||||
@@ -733,9 +735,10 @@ class Host(AbortableEventEmitter):
|
||||
):
|
||||
if connection := self.connections.get(connection_handle):
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
elif not (
|
||||
self.cis_links.get(connection_handle)
|
||||
or self.sco_links.get(connection_handle)
|
||||
elif connection_handle not in itertools.chain(
|
||||
self.cis_links.keys(),
|
||||
self.sco_links.keys(),
|
||||
itertools.chain.from_iterable(self.bigs.values()),
|
||||
):
|
||||
logger.warning(
|
||||
'received packet completion event for unknown handle '
|
||||
@@ -953,6 +956,50 @@ class Host(AbortableEventEmitter):
|
||||
event.cis_id,
|
||||
)
|
||||
|
||||
def on_hci_le_create_big_complete_event(self, event):
|
||||
self.bigs[event.big_handle] = set(event.connection_handle)
|
||||
self.emit(
|
||||
'big_establishment',
|
||||
event.status,
|
||||
event.big_handle,
|
||||
event.connection_handle,
|
||||
event.big_sync_delay,
|
||||
event.transport_latency_big,
|
||||
event.phy,
|
||||
event.nse,
|
||||
event.bn,
|
||||
event.pto,
|
||||
event.irc,
|
||||
event.max_pdu,
|
||||
event.iso_interval,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_established_event(self, event):
|
||||
self.emit(
|
||||
'big_sync_establishment',
|
||||
event.status,
|
||||
event.big_handle,
|
||||
event.transport_latency_big,
|
||||
event.nse,
|
||||
event.bn,
|
||||
event.pto,
|
||||
event.irc,
|
||||
event.max_pdu,
|
||||
event.iso_interval,
|
||||
event.connection_handle,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_lost_event(self, event):
|
||||
self.emit(
|
||||
'big_sync_lost',
|
||||
event.big_handle,
|
||||
event.reason,
|
||||
)
|
||||
|
||||
def on_hci_le_terminate_big_complete_event(self, event):
|
||||
self.bigs.pop(event.big_handle)
|
||||
self.emit('big_termination', event.reason, event.big_handle)
|
||||
|
||||
def on_hci_le_cis_established_event(self, event):
|
||||
# The remaining parameters are unused for now.
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
|
||||
@@ -773,7 +773,6 @@ class ClassicChannel(EventEmitter):
|
||||
self.psm = psm
|
||||
self.source_cid = source_cid
|
||||
self.destination_cid = 0
|
||||
self.response = None
|
||||
self.connection_result = None
|
||||
self.disconnection_result = None
|
||||
self.sink = None
|
||||
@@ -783,27 +782,15 @@ class ClassicChannel(EventEmitter):
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
|
||||
async def send_request(self, request: SupportsBytes) -> bytes:
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
|
||||
self.response = asyncio.get_running_loop().create_future()
|
||||
self.send_pdu(request)
|
||||
return await self.response
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.response:
|
||||
self.response.set_result(pdu)
|
||||
self.response = None
|
||||
elif self.sink:
|
||||
if self.sink:
|
||||
# pylint: disable=not-callable
|
||||
self.sink(pdu)
|
||||
else:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
@@ -258,8 +259,8 @@ class AseReasonCode(enum.IntEnum):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioRole(enum.IntEnum):
|
||||
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||
SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -354,16 +355,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||
|
||||
async def post_cis_established():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
data_path_id=0x00, # Fixed HCI
|
||||
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=0,
|
||||
codec_configuration=b'',
|
||||
)
|
||||
)
|
||||
await cis_link.setup_data_path(direction=self.role)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.STREAMING
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
@@ -511,12 +503,8 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.state = self.State.RELEASING
|
||||
|
||||
async def remove_cis_async():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
)
|
||||
)
|
||||
if self.cis_link:
|
||||
await self.cis_link.remove_data_path(self.role)
|
||||
self.state = self.State.IDLE
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
|
||||
314
bumble/sdp.py
314
bumble/sdp.py
@@ -16,15 +16,21 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||
from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
from bumble import core, l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
InvalidStateError,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device import Device, Connection
|
||||
@@ -242,11 +248,11 @@ class DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value: List[DataElement]) -> DataElement:
|
||||
def sequence(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value: List[DataElement]) -> DataElement:
|
||||
def alternative(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -473,7 +479,9 @@ class ServiceAttribute:
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
|
||||
def list_from_data_elements(
|
||||
elements: Sequence[DataElement],
|
||||
) -> list[ServiceAttribute]:
|
||||
attribute_list = []
|
||||
for i in range(0, len(elements) // 2):
|
||||
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
||||
@@ -486,7 +494,7 @@ class ServiceAttribute:
|
||||
|
||||
@staticmethod
|
||||
def find_attribute_in_list(
|
||||
attribute_list: List[ServiceAttribute], attribute_id: int
|
||||
attribute_list: Iterable[ServiceAttribute], attribute_id: int
|
||||
) -> Optional[DataElement]:
|
||||
return next(
|
||||
(
|
||||
@@ -534,7 +542,12 @@ class SDP_PDU:
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
|
||||
'''
|
||||
|
||||
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
|
||||
RESPONSE_PDU_IDS = {
|
||||
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
|
||||
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
|
||||
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
|
||||
}
|
||||
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
|
||||
name = None
|
||||
pdu_id = 0
|
||||
|
||||
@@ -558,7 +571,7 @@ class SDP_PDU:
|
||||
@staticmethod
|
||||
def parse_service_record_handle_list_preceded_by_count(
|
||||
data: bytes, offset: int
|
||||
) -> Tuple[int, List[int]]:
|
||||
) -> tuple[int, list[int]]:
|
||||
count = struct.unpack_from('>H', data, offset - 2)[0]
|
||||
handle_list = [
|
||||
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
||||
@@ -639,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
|
||||
'''
|
||||
|
||||
error_code: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -675,7 +690,7 @@ class SDP_ServiceSearchResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
||||
'''
|
||||
|
||||
service_record_handle_list: List[int]
|
||||
service_record_handle_list: list[int]
|
||||
total_service_record_count: int
|
||||
current_service_record_count: int
|
||||
continuation_state: bytes
|
||||
@@ -752,31 +767,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_list_byte_count: int
|
||||
attribute_list: bytes
|
||||
attribute_lists_byte_count: int
|
||||
attribute_lists: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
def __init__(self, connection: Connection, mtu: int = 0) -> None:
|
||||
self.connection = connection
|
||||
self.pending_request = None
|
||||
self.channel = None
|
||||
self.channel: Optional[l2cap.ClassicChannel] = None
|
||||
self.mtu = mtu
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request: Optional[SDP_PDU] = None
|
||||
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
|
||||
self.next_transaction_id = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.channel = await self.connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
spec=(
|
||||
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
|
||||
if self.mtu
|
||||
else l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
)
|
||||
)
|
||||
self.channel.sink = self.on_pdu
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.channel:
|
||||
await self.channel.disconnect()
|
||||
self.channel = None
|
||||
|
||||
async def search_services(self, uuids: List[core.UUID]) -> List[int]:
|
||||
def make_transaction_id(self) -> int:
|
||||
transaction_id = self.next_transaction_id
|
||||
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
|
||||
return transaction_id
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if not self.pending_request:
|
||||
logger.warning('received response with no pending request')
|
||||
return
|
||||
assert self.pending_response is not None
|
||||
|
||||
response = SDP_PDU.from_bytes(pdu)
|
||||
|
||||
# Check that the transaction ID is what we expect
|
||||
if self.pending_request.transaction_id != response.transaction_id:
|
||||
logger.warning(
|
||||
f"received response with transaction ID {response.transaction_id} "
|
||||
f"but expected {self.pending_request.transaction_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if the response is an error
|
||||
if isinstance(response, SDP_ErrorResponse):
|
||||
self.pending_response.set_exception(
|
||||
ProtocolError(error_code=response.error_code)
|
||||
)
|
||||
return
|
||||
|
||||
# Check that the type of the response matches the request
|
||||
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
|
||||
logger.warning("response type mismatch")
|
||||
return
|
||||
|
||||
self.pending_response.set_result(response)
|
||||
|
||||
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
|
||||
assert self.channel is not None
|
||||
async with self.request_semaphore:
|
||||
assert self.pending_request is None
|
||||
assert self.pending_response is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_response = asyncio.get_running_loop().create_future()
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.channel.send_pdu(bytes(request))
|
||||
return await self.pending_response
|
||||
finally:
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
|
||||
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
|
||||
"""
|
||||
Search for services by UUID.
|
||||
|
||||
Args:
|
||||
uuids: service the UUIDs to search for.
|
||||
|
||||
Returns:
|
||||
A list of matching service record handles.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -791,16 +874,16 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_service_record_count=0xFFFF,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchResponse)
|
||||
service_record_handle_list += response.service_record_handle_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -811,8 +894,21 @@ class Client:
|
||||
return service_record_handle_list
|
||||
|
||||
async def search_attributes(
|
||||
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
|
||||
) -> List[List[ServiceAttribute]]:
|
||||
self,
|
||||
uuids: Iterable[core.UUID],
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[list[ServiceAttribute]]:
|
||||
"""
|
||||
Search for attributes by UUID and attribute IDs.
|
||||
|
||||
Args:
|
||||
uuids: the service UUIDs to search for.
|
||||
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
|
||||
(use (0, 0xFFFF) to include all attributes)
|
||||
|
||||
Returns:
|
||||
A list of list of attributes, one list per matching service.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -824,8 +920,8 @@ class Client:
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
@@ -839,17 +935,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
|
||||
accumulator += response.attribute_lists
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -872,8 +968,18 @@ class Client:
|
||||
async def get_attributes(
|
||||
self,
|
||||
service_record_handle: int,
|
||||
attribute_ids: List[Union[int, Tuple[int, int]]],
|
||||
) -> List[ServiceAttribute]:
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[ServiceAttribute]:
|
||||
"""
|
||||
Get attributes for a service.
|
||||
|
||||
Args:
|
||||
service_record_handle: the handle for a service
|
||||
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
|
||||
|
||||
Returns:
|
||||
A list of attributes.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -882,8 +988,8 @@ class Client:
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
@@ -897,17 +1003,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_record_handle=service_record_handle,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceAttributeResponse)
|
||||
accumulator += response.attribute_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -933,17 +1039,17 @@ class Client:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
CONTINUATION_STATE = bytes([0x01, 0x00])
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', List[ServiceAttribute])
|
||||
service_records: Dict[int, Service]
|
||||
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||
Service = NewType('Service', list[ServiceAttribute])
|
||||
service_records: dict[int, Service]
|
||||
current_response: Union[None, bytes, tuple[int, list[int]]]
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
self.service_records = {} # Service records maps, by record handle
|
||||
self.channel = None
|
||||
self.current_response = None
|
||||
self.current_response = None # Current response data, used for continuations
|
||||
|
||||
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||
l2cap_channel_manager.create_classic_server(
|
||||
@@ -954,7 +1060,7 @@ class Server:
|
||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||
self.channel.send_pdu(response)
|
||||
|
||||
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
|
||||
def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
|
||||
# Find the services for which the attributes in the pattern is a subset of the
|
||||
# service's attribute values (NOTE: the value search recurses into sequences)
|
||||
matching_services = {}
|
||||
@@ -1011,6 +1117,31 @@ class Server:
|
||||
)
|
||||
)
|
||||
|
||||
def check_continuation(
|
||||
self,
|
||||
continuation_state: bytes,
|
||||
transaction_id: int,
|
||||
) -> Optional[bool]:
|
||||
# Check if this is a valid continuation
|
||||
if len(continuation_state) > 1:
|
||||
if (
|
||||
self.current_response is None
|
||||
or continuation_state != self.CONTINUATION_STATE
|
||||
):
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return None
|
||||
return True
|
||||
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
|
||||
return False
|
||||
|
||||
def get_next_response_payload(self, maximum_size):
|
||||
if len(self.current_response) > maximum_size:
|
||||
payload = self.current_response[:maximum_size]
|
||||
@@ -1025,7 +1156,7 @@ class Server:
|
||||
|
||||
@staticmethod
|
||||
def get_service_attributes(
|
||||
service: Service, attribute_ids: List[DataElement]
|
||||
service: Service, attribute_ids: Iterable[DataElement]
|
||||
) -> DataElement:
|
||||
attributes = []
|
||||
for attribute_id in attribute_ids:
|
||||
@@ -1053,30 +1184,24 @@ class Server:
|
||||
|
||||
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(request.service_search_pattern)
|
||||
service_record_handles = list(matching_services.keys())
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
|
||||
# Only return up to the maximum requested
|
||||
service_record_handles_subset = service_record_handles[
|
||||
: request.maximum_service_record_count
|
||||
]
|
||||
|
||||
# Serialize to a byte array, and remember the total count
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
self.current_response = (
|
||||
len(service_record_handles),
|
||||
service_record_handles_subset,
|
||||
@@ -1084,15 +1209,21 @@ class Server:
|
||||
|
||||
# Respond, keeping any unsent handles for later
|
||||
assert isinstance(self.current_response, tuple)
|
||||
service_record_handles = self.current_response[1][
|
||||
: request.maximum_service_record_count
|
||||
assert self.channel is not None
|
||||
total_service_record_count, service_record_handles = self.current_response
|
||||
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
|
||||
service_record_handles_remaining = service_record_handles[
|
||||
maximum_service_record_count:
|
||||
]
|
||||
service_record_handles = service_record_handles[:maximum_service_record_count]
|
||||
self.current_response = (
|
||||
self.current_response[0],
|
||||
self.current_response[1][request.maximum_service_record_count :],
|
||||
total_service_record_count,
|
||||
service_record_handles_remaining,
|
||||
)
|
||||
continuation_state = (
|
||||
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
|
||||
Server.CONTINUATION_STATE
|
||||
if service_record_handles_remaining
|
||||
else bytes([0])
|
||||
)
|
||||
service_record_handle_list = b''.join(
|
||||
[struct.pack('>I', handle) for handle in service_record_handles]
|
||||
@@ -1100,7 +1231,7 @@ class Server:
|
||||
self.send_response(
|
||||
SDP_ServiceSearchResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
total_service_record_count=self.current_response[0],
|
||||
total_service_record_count=total_service_record_count,
|
||||
current_service_record_count=len(service_record_handles),
|
||||
service_record_handle_list=service_record_handle_list,
|
||||
continuation_state=continuation_state,
|
||||
@@ -1111,19 +1242,14 @@ class Server:
|
||||
self, request: SDP_ServiceAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Check that the service exists
|
||||
service = self.service_records.get(request.service_record_handle)
|
||||
if service is None:
|
||||
@@ -1145,14 +1271,18 @@ class Server:
|
||||
self.current_response = bytes(attribute_list)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_list_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_list_byte_count=len(attribute_list_response),
|
||||
attribute_list=attribute_list,
|
||||
attribute_list=attribute_list_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
@@ -1161,18 +1291,14 @@ class Server:
|
||||
self, request: SDP_ServiceSearchAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(
|
||||
request.service_search_pattern
|
||||
@@ -1192,14 +1318,18 @@ class Server:
|
||||
self.current_response = bytes(attribute_lists)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_lists_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceSearchAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_lists_byte_count=len(attribute_lists_response),
|
||||
attribute_lists=attribute_lists,
|
||||
attribute_lists=attribute_lists_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
2
bumble/vendor/android/hci.py
vendored
2
bumble/vendor/android/hci.py
vendored
@@ -299,7 +299,7 @@ class HCI_Android_Vendor_Event(HCI_Extended_Event):
|
||||
|
||||
|
||||
HCI_Android_Vendor_Event.register_subevents(globals())
|
||||
HCI_Event.vendor_factory = HCI_Android_Vendor_Event.subclass_from_parameters
|
||||
HCI_Event.add_vendor_factory(HCI_Android_Vendor_Event.subclass_from_parameters)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
|
||||
from bumble.hci import (
|
||||
HCI_DISCONNECT_COMMAND,
|
||||
@@ -22,6 +23,7 @@ from bumble.hci import (
|
||||
HCI_LE_CODED_PHY_BIT,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_RESET_COMMAND,
|
||||
HCI_VENDOR_EVENT,
|
||||
HCI_SUCCESS,
|
||||
HCI_LE_CONNECTION_COMPLETE_EVENT,
|
||||
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT,
|
||||
@@ -67,6 +69,7 @@ from bumble.hci import (
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
HCI_Reset_Command,
|
||||
HCI_Set_Event_Mask_Command,
|
||||
HCI_Vendor_Event,
|
||||
)
|
||||
|
||||
|
||||
@@ -213,6 +216,41 @@ def test_HCI_Number_Of_Completed_Packets_Event():
|
||||
basic_check(event)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_HCI_Vendor_Event():
|
||||
data = bytes.fromhex('01020304')
|
||||
event = HCI_Vendor_Event(data=data)
|
||||
event_bytes = bytes(event)
|
||||
parsed = HCI_Packet.from_bytes(event_bytes)
|
||||
assert isinstance(parsed, HCI_Vendor_Event)
|
||||
assert parsed.data == data
|
||||
|
||||
class HCI_Custom_Event(HCI_Event):
|
||||
def __init__(self, blabla):
|
||||
super().__init__(HCI_VENDOR_EVENT, parameters=struct.pack("<I", blabla))
|
||||
self.name = 'HCI_CUSTOM_EVENT'
|
||||
self.blabla = blabla
|
||||
|
||||
def create_event(payload):
|
||||
if payload[0] == 1:
|
||||
return HCI_Custom_Event(blabla=struct.unpack('<I', payload)[0])
|
||||
return None
|
||||
|
||||
HCI_Event.add_vendor_factory(create_event)
|
||||
parsed = HCI_Packet.from_bytes(event_bytes)
|
||||
assert isinstance(parsed, HCI_Custom_Event)
|
||||
assert parsed.blabla == 0x04030201
|
||||
event_bytes2 = event_bytes[:3] + bytes([7]) + event_bytes[4:]
|
||||
parsed = HCI_Packet.from_bytes(event_bytes2)
|
||||
assert not isinstance(parsed, HCI_Custom_Event)
|
||||
assert isinstance(parsed, HCI_Vendor_Event)
|
||||
HCI_Event.remove_vendor_factory(create_event)
|
||||
|
||||
parsed = HCI_Packet.from_bytes(event_bytes)
|
||||
assert not isinstance(parsed, HCI_Custom_Event)
|
||||
assert isinstance(parsed, HCI_Vendor_Event)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_HCI_Command():
|
||||
command = HCI_Command(0x5566)
|
||||
@@ -576,6 +614,7 @@ def run_test_events():
|
||||
test_HCI_Command_Complete_Event()
|
||||
test_HCI_Command_Status_Event()
|
||||
test_HCI_Number_Of_Completed_Packets_Event()
|
||||
test_HCI_Vendor_Event()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -20,12 +20,11 @@ import logging
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
|
||||
from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID
|
||||
from bumble.sdp import (
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
Client,
|
||||
Server,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
@@ -174,9 +173,10 @@ def test_data_elements() -> None:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def sdp_records():
|
||||
def sdp_records(record_count=1):
|
||||
return {
|
||||
0x00010001: [
|
||||
0x00010001
|
||||
+ i: [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(0x00010001),
|
||||
@@ -200,6 +200,7 @@ def sdp_records():
|
||||
),
|
||||
),
|
||||
]
|
||||
for i in range(record_count)
|
||||
}
|
||||
|
||||
|
||||
@@ -216,19 +217,55 @@ async def test_service_search():
|
||||
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||
|
||||
# Search for service
|
||||
client = Client(devices.connections[1])
|
||||
await client.connect()
|
||||
services = await client.search_services(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
|
||||
)
|
||||
async with Client(devices.connections[1]) as client:
|
||||
services = await client.search_services(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AF')]
|
||||
)
|
||||
assert len(services) == 0
|
||||
|
||||
# Then
|
||||
assert services[0] == 0x00010001
|
||||
services = await client.search_services(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
|
||||
)
|
||||
assert len(services) == 1
|
||||
assert services[0] == 0x00010001
|
||||
|
||||
services = await client.search_services(
|
||||
[BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT]
|
||||
)
|
||||
assert len(services) == 1
|
||||
assert services[0] == 0x00010001
|
||||
|
||||
services = await client.search_services(
|
||||
[BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT]
|
||||
)
|
||||
assert len(services) == 1
|
||||
assert services[0] == 0x00010001
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_attribute():
|
||||
async def test_service_search_with_continuation():
|
||||
# Setup connections
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
# Register SDP service
|
||||
records = sdp_records(100)
|
||||
devices.devices[0].sdp_server.service_records.update(records)
|
||||
|
||||
# Search for service
|
||||
async with Client(devices.connections[1], mtu=48) as client:
|
||||
services = await client.search_services(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
|
||||
)
|
||||
assert len(services) == len(records)
|
||||
for i in range(len(records)):
|
||||
assert services[i] == 0x00010001 + i
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_attributes():
|
||||
# Setup connections
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
@@ -236,15 +273,43 @@ async def test_service_attribute():
|
||||
# Register SDP service
|
||||
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||
|
||||
# Search for service
|
||||
client = Client(devices.connections[1])
|
||||
await client.connect()
|
||||
attributes = await client.get_attributes(
|
||||
0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
|
||||
)
|
||||
# Get attributes
|
||||
async with Client(devices.connections[1]) as client:
|
||||
attributes = await client.get_attributes(0x00010001, [1234])
|
||||
assert len(attributes) == 0
|
||||
|
||||
# Then
|
||||
assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value
|
||||
attributes = await client.get_attributes(
|
||||
0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
|
||||
)
|
||||
assert len(attributes) == 1
|
||||
assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_attributes_with_continuation():
|
||||
# Setup connections
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
# Register SDP service
|
||||
records = {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
x,
|
||||
DataElement.unsigned_integer_32(0x00010001),
|
||||
)
|
||||
for x in range(100)
|
||||
]
|
||||
}
|
||||
devices.devices[0].sdp_server.service_records.update(records)
|
||||
|
||||
# Get attributes
|
||||
async with Client(devices.connections[1], mtu=48) as client:
|
||||
attributes = await client.get_attributes(0x00010001, list(range(100)))
|
||||
assert len(attributes) == 100
|
||||
for i, attribute in enumerate(attributes):
|
||||
assert attribute.id == i
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -255,19 +320,81 @@ async def test_service_search_attribute():
|
||||
await devices.setup_connection()
|
||||
|
||||
# Register SDP service
|
||||
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||
records = {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
4,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
3,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
1,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
|
||||
),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
devices.devices[0].sdp_server.service_records.update(records)
|
||||
|
||||
# Search for service
|
||||
client = Client(devices.connections[1])
|
||||
await client.connect()
|
||||
attributes = await client.search_attributes(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)]
|
||||
)
|
||||
async with Client(devices.connections[1]) as client:
|
||||
attributes = await client.search_attributes(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)]
|
||||
)
|
||||
assert len(attributes) == 1
|
||||
assert len(attributes[0]) == 3
|
||||
assert attributes[0][0].id == 1
|
||||
assert attributes[0][1].id == 3
|
||||
assert attributes[0][2].id == 4
|
||||
|
||||
# Then
|
||||
for expect, actual in zip(attributes, sdp_records().values()):
|
||||
assert expect.id == actual.id
|
||||
assert expect.value == actual.value
|
||||
attributes = await client.search_attributes(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [1, 2, 3]
|
||||
)
|
||||
assert len(attributes) == 1
|
||||
assert len(attributes[0]) == 2
|
||||
assert attributes[0][0].id == 1
|
||||
assert attributes[0][1].id == 3
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_search_attribute_with_continuation():
|
||||
# Setup connections
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
# Register SDP service
|
||||
records = {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
x,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
|
||||
),
|
||||
)
|
||||
for x in range(100)
|
||||
]
|
||||
}
|
||||
devices.devices[0].sdp_server.service_records.update(records)
|
||||
|
||||
# Search for service
|
||||
async with Client(devices.connections[1], mtu=48) as client:
|
||||
attributes = await client.search_attributes(
|
||||
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)]
|
||||
)
|
||||
assert len(attributes) == 1
|
||||
assert len(attributes[0]) == 100
|
||||
for i in range(100):
|
||||
assert attributes[0][i].id == i
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -287,9 +414,12 @@ async def test_client_async_context():
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_data_elements()
|
||||
await test_service_attribute()
|
||||
await test_service_attributes()
|
||||
await test_service_attributes_with_continuation()
|
||||
await test_service_search()
|
||||
await test_service_search_with_continuation()
|
||||
await test_service_search_attribute()
|
||||
await test_service_search_attribute_with_continuation()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user