Compare commits

...

52 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
a311c3f723 hotfix for usb transport 2024-08-22 22:26:44 -07:00
zxzxwu
b2bb82a432 Merge pull request #537 from zxzxwu/smp
Ignore invalid RPA
2024-08-21 13:54:02 +08:00
Josh Wu
597560ff80 Ignore invalid local resolvable address 2024-08-21 00:11:14 +08:00
Gilles Boccon-Gibod
db383bb3e6 Merge pull request #531 from AlanRosenthal/btbench-scan
BtBench: Add Scan functionality
2024-08-14 11:59:13 -07:00
Alan Rosenthal
ccc5bbdad4 BtBench: Scan 2024-08-14 11:26:31 -04:00
zxzxwu
11c8229017 Merge pull request #533 from zxzxwu/hid
Correct HID type annotations
2024-08-14 12:08:53 +08:00
Josh Wu
2248f9ae5e Correct HID type annotations 2024-08-13 23:13:33 +08:00
Gilles Boccon-Gibod
03c79aacb2 Merge pull request #529 from google/gbg/broadcast-assistant
basic broadcast assistant functionality
2024-08-12 13:02:50 -07:00
zxzxwu
0c31713a8e Merge pull request #528 from zxzxwu/rpa
Fix CTKD failure introduced by Host RPA generation
2024-08-13 01:30:19 +08:00
Gilles Boccon-Gibod
9dd814f32e strict compliance check 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
ab6e595bcb fix typing 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f08fac8c8a catch ATT errors 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
a699520188 fix after rebase merge 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f66633459e wip 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f3b776c343 wip 2024-08-12 08:31:37 -07:00
Gilles Boccon-Gibod
de7b99ce34 wip 2024-08-12 08:29:32 -07:00
Gilles Boccon-Gibod
c0b17d9aff Merge pull request #530 from google/gbg/usb-no-parser
don't user a parser for a usb source
2024-08-12 08:21:19 -07:00
zxzxwu
3c12be59c5 Merge pull request #527 from zxzxwu/scan
Support Interlaced Scan config
2024-08-12 15:15:49 +08:00
Josh Wu
c6b3deb8df Fix CTKD failure introduced by Host RPA generation 2024-08-12 15:13:40 +08:00
Gilles Boccon-Gibod
a0b5606047 don't user a parser for a usb source 2024-08-11 20:57:45 -07:00
Josh Wu
3824e38485 Support Interlaced Scan config 2024-08-09 22:09:26 +08:00
Gilles Boccon-Gibod
4433184048 Merge pull request #522 from google/gbg/rpa2
add basic RPA support
2024-08-06 10:35:39 -07:00
Gilles Boccon-Gibod
312fc8db36 support controller-generated rpa 2024-08-05 08:59:05 -07:00
Gilles Boccon-Gibod
615691ec81 add basic RPA support 2024-08-01 15:37:11 -07:00
zxzxwu
ae8b83f294 Merge pull request #521 from zxzxwu/bap
Add Metadata LTV serializer and adapt Unicast
2024-07-31 11:36:46 +08:00
Josh Wu
4a8e21f4db Add Metadata LTV serializer and adapt Unicast 2024-07-31 01:20:28 +08:00
zxzxwu
3462e7c437 Merge pull request #439 from zxzxwu/mcp
Media Control Service Client implementation
2024-07-24 23:45:00 +08:00
Josh Wu
0f2e5239ad MCP constants and Client implementation 2024-07-24 22:57:26 +08:00
Gilles Boccon-Gibod
ee48cdc63f Merge pull request #517 from AlanRosenthal/scanner_pyee
Update scanner.py to use pyee.EventEmitter
2024-07-18 12:53:00 -07:00
Gilles Boccon-Gibod
1c278bec93 Merge pull request #518 from google/gbg/usb-queue
USB: better packet queue logic
2024-07-18 12:51:00 -07:00
Gilles Boccon-Gibod
6a51166af7 better packet queue logic 2024-07-17 17:48:26 -07:00
Alan Rosenthal
85d79fa914 Update scanner.py to use pyee.EventEmitter 2024-07-17 16:53:50 -04:00
zxzxwu
142bdce94a Merge pull request #515 from zxzxwu/unix
Add UNIX socket transport
2024-07-17 16:04:38 +08:00
Josh Wu
881a5a64b5 Add UNIX socket transport 2024-07-17 00:41:04 +08:00
zxzxwu
5aae44b610 Merge pull request #501 from zxzxwu/exception
Reorganize exceptions
2024-07-12 15:44:58 +08:00
Gilles Boccon-Gibod
e3ea167827 Merge pull request #506 from google/gbg/a2dp-fixes
a2dp: emit delay_report
2024-07-11 18:46:06 -07:00
Gilles Boccon-Gibod
eec145e095 add type hint 2024-07-11 18:39:02 -07:00
Gilles Boccon-Gibod
87fa02d6e5 Merge pull request #507 from google/packageFile
Create `inv web.build`
2024-07-11 18:35:29 -07:00
Gilles Boccon-Gibod
ad94c1e1f3 Merge pull request #509 from AlanRosenthal/discover
device.py: Add discover_all() api
2024-07-11 18:34:29 -07:00
Gilles Boccon-Gibod
546a0bce8d Merge pull request #510 from AlanRosenthal/get_characteristics_by_uuid
device.py: Update get_characteristics_by_uuid()
2024-07-11 18:33:45 -07:00
Gilles Boccon-Gibod
cb7ca44a1c Merge pull request #512 from AlanRosenthal/favicon
Add favicon.ico to docs folder
2024-07-11 18:27:19 -07:00
Gilles Boccon-Gibod
4081b93407 Merge pull request #513 from AlanRosenthal/devcontainer
Add devcontainer.json
2024-07-11 18:24:09 -07:00
Alan Rosenthal
26203ebaad Add devcontainer.json
devcontainer.json allows github's codespaces to be created with bumble's dependencies already installed
2024-07-11 18:47:32 +00:00
Alan Rosenthal
3389e3e1ed device.py: Update get_characteristics_by_uuid()
`get_characteristics_by_uuid()` now allows a UUID to be passed to the
service param. This allows for users to easily query for a service uuid
and characteristic uuid with one API.
2024-07-11 18:05:41 +00:00
Alan Rosenthal
7e1f01c01e Add favicon.ico to docs folder
Generated via: realfavicongenerator.net

validated via:
```
$ icotool -l favicon.ico
--icon --index=1 --width=48 --height=48 --bit-depth=32 --palette-size=0
--icon --index=2 --width=32 --height=32 --bit-depth=32 --palette-size=0
--icon --index=3 --width=16 --height=16 --bit-depth=32 --palette-size=0
```
2024-07-11 09:47:19 -04:00
Gilles Boccon-Gibod
613e15548a Merge pull request #511 from AlanRosenthal/random
console.py: Use Address.generate_static_address
2024-07-10 13:45:52 -07:00
Alan Rosenthal
e09c91df8e console.py: Use Address.generate_static_address 2024-07-10 18:51:46 +00:00
Alan Rosenthal
df206667b6 device.py: Add discover_all() api 2024-07-10 13:24:08 -04:00
Gilles Boccon-Gibod
0f19dd5263 Merge pull request #508 from google/web-readme
Add tip about disabling caching to web's readme
2024-07-09 09:17:25 -07:00
Alan Rosenthal
b98e4937f3 Add tip about disabling caching to web's readme 2024-07-09 13:48:53 +00:00
Gilles Boccon-Gibod
27791cf218 emit delay_report 2024-07-03 13:51:15 -07:00
Josh Wu
f8a2d4f0e0 Reorganize exceptions
* Add BaseBumbleException as a "real" root error
* Add several core error classes and properly replace builtin errors
  with them
* Add several error classes for specific modules (transport, device)
2024-06-11 16:13:08 +08:00
65 changed files with 4470 additions and 1416 deletions

View File

@@ -0,0 +1,30 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/universal:2",
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand":
"python -m pip install '.[build,test,development,documentation]'",
// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python"
]
}
}
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -17,10 +17,11 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import dataclasses import dataclasses
import logging import logging
import os import os
from typing import cast, Dict, Optional, Tuple from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
import click import click
import pyee import pyee
@@ -32,6 +33,7 @@ import bumble.device
import bumble.gatt import bumble.gatt
import bumble.hci import bumble.hci
import bumble.profiles.bap import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp import bumble.profiles.pbp
import bumble.transport import bumble.transport
import bumble.utils import bumble.utils
@@ -46,14 +48,16 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast" AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5") AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Discover Broadcasts # Scan For Broadcasts
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BroadcastDiscoverer: class BroadcastScanner(pyee.EventEmitter):
@dataclasses.dataclass @dataclasses.dataclass
class Broadcast(pyee.EventEmitter): class Broadcast(pyee.EventEmitter):
name: str name: str
@@ -79,22 +83,6 @@ class BroadcastDiscoverer:
self.sync.on('periodic_advertisement', self.on_periodic_advertisement) self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement) self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None: def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all( for service_data in advertisement.data.get_all(
@@ -139,6 +127,8 @@ class BroadcastDiscoverer:
data, data,
) )
self.emit('update')
def print(self) -> None: def print(self) -> None:
print( print(
color('Broadcast:', 'yellow'), color('Broadcast:', 'yellow'),
@@ -227,13 +217,12 @@ class BroadcastDiscoverer:
) )
def on_sync_establishment(self) -> None: def on_sync_establishment(self) -> None:
self.establishment_timeout_task.cancel() self.emit('sync_establishment')
self.emit('change')
def on_sync_loss(self) -> None: def on_sync_loss(self) -> None:
self.basic_audio_announcement = None self.basic_audio_announcement = None
self.biginfo = None self.biginfo = None
self.emit('change') self.emit('sync_loss')
def on_periodic_advertisement( def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement self, advertisement: bumble.device.PeriodicAdvertisement
@@ -268,37 +257,21 @@ class BroadcastDiscoverer:
filter_duplicates: bool, filter_duplicates: bool,
sync_timeout: float, sync_timeout: float,
): ):
super().__init__()
self.device = device self.device = device
self.filter_duplicates = filter_duplicates self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {} self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement) device.on('advertisement', self.on_advertisement)
async def run(self) -> None: async def start(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning( await self.device.start_scanning(
active=False, active=False,
filter_duplicates=False, filter_duplicates=False,
) )
def refresh(self) -> None: async def stop(self) -> None:
# Clear the screen from the top await self.device.stop_scanning()
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if ( if (
@@ -311,7 +284,6 @@ class BroadcastDiscoverer:
if broadcast := self.broadcasts.get(advertisement.address): if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement) broadcast.update(advertisement)
self.refresh()
return return
bumble.utils.AsyncRunner.spawn( bumble.utils.AsyncRunner.spawn(
@@ -331,41 +303,318 @@ class BroadcastDiscoverer:
name, name,
periodic_advertising_sync, periodic_advertising_sync,
) )
broadcast.on('change', self.refresh)
broadcast.update(advertisement) broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast)) periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.status_message = color( self.emit('new_broadcast', broadcast)
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None: def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address] del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate()) bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner:
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
self.scanner.on('update', self.refresh)
self.status_message = ''
async def start(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.scanner.start()
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color( self.status_message = color(
f'-Found {len(self.broadcasts)} broadcasts', 'green' f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
)
broadcast.on('change', self.refresh)
broadcast.on('update', self.refresh)
self.refresh()
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
) )
self.refresh() self.refresh()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
async def run_discover_broadcasts( # Print the status message
filter_duplicates: bool, sync_timeout: float, transport: str print(self.status_message)
) -> None: print("==========================================")
# Print all broadcasts
for broadcast in self.scanner.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
@contextlib.asynccontextmanager
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
async with await bumble.transport.open_transport(transport) as ( async with await bumble.transport.open_transport(transport) as (
hci_source, hci_source,
hci_sink, hci_sink,
): ):
device = bumble.device.Device.with_hci( device_config = bumble.device.DeviceConfiguration(
AURACAST_DEFAULT_DEVICE_NAME, name=AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS, address=AURACAST_DEFAULT_DEVICE_ADDRESS,
keystore='JsonKeyStore',
)
device = bumble.device.Device.from_config_with_hci(
device_config,
hci_source, hci_source,
hci_sink, hci_sink,
) )
await device.power_on() await device.power_on()
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run() yield device
await hci_source.terminated
async def find_broadcast_by_name(
device: bumble.device.Device, name: Optional[str]
) -> BroadcastScanner.Broadcast:
result = asyncio.get_running_loop().create_future()
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
if broadcast.basic_audio_announcement and not result.done():
print(color('Broadcast basic audio announcement received', 'green'))
result.set_result(broadcast)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if name is None or broadcast.name == name:
print(color('Broadcast found:', 'green'), broadcast.name)
broadcast.on('change', lambda: on_broadcast_change(broadcast))
return
print(color(f'Skipping broadcast {broadcast.name}'))
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
broadcast = await result
await scanner.stop()
return broadcast
async def run_scan(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
await scanner.start()
await asyncio.get_running_loop().create_future()
async def run_assist(
broadcast_name: Optional[str],
source_id: Optional[int],
command: str,
transport: str,
address: str,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
# Connect to the server
print(f'=== Connecting to {address}...')
connection = await device.connect(address)
peer = bumble.device.Peer(connection)
print(f'=== Connected to {peer}')
print("+++ Encrypting connection...")
await peer.connection.encrypt()
print("+++ Connection encrypted")
# Request a larger MTU
mtu = AURACAST_DEFAULT_ATT_MTU
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
await peer.request_mtu(mtu)
# Get the BASS service
bass = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy
)
# Check that the service was found
if not bass:
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return
# Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
try:
await broadcast_receive_state.subscribe(
lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
)
)
except bumble.core.ProtocolError as error:
print(
color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
'red',
),
error,
)
value = await broadcast_receive_state.read_value()
print(
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
)
if command == 'monitor-state':
await peer.sustain()
return
if command == 'add-source':
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source(
broadcast.sync.advertiser_address,
broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
# Initiate a PA Sync Transfer
await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning.
await bass.remote_scan_stopped()
await peer.sustain()
return
if command == 'modify-source':
if source_id is None:
print(color('!!! modify-source requires --source-id'))
return
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Modify the source
print(
color('Modifying source:', 'blue'),
source_id,
)
await bass.modify_source(
source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
await peer.sustain()
return
if command == 'remove-source':
if source_id is None:
print(color('!!! remove-source requires --source-id'))
return
# Remove the source
print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id)
await peer.sustain()
return
print(color(f'!!! invalid command {command}'))
async def run_pair(transport: str, address: str) -> None:
async with create_device(transport) as device:
# Connect to the server
print(f'=== Connecting to {address}...')
async with device.connect_as_gatt(address) as peer:
print(f'=== Connected to {peer}')
print("+++ Initiating pairing...")
await peer.connection.pair()
print("+++ Paired")
def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
except bumble.core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError
):
message = bumble.profiles.bass.ApplicationError(error.error_code).name
else:
message = str(error)
print(
color('!!! An error occurred while executing the command:', 'red'), message
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -379,7 +628,7 @@ def auracast(
ctx.ensure_object(dict) ctx.ensure_object(dict)
@auracast.command('discover-broadcasts') @auracast.command('scan')
@click.option( @click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates' '--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
) )
@@ -387,14 +636,50 @@ def auracast(
'--sync-timeout', '--sync-timeout',
metavar='SYNC_TIMEOUT', metavar='SYNC_TIMEOUT',
type=float, type=float,
default=5.0, default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)', help='Sync timeout (in seconds)',
) )
@click.argument('transport') @click.argument('transport')
@click.pass_context @click.pass_context
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport): def scan(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts""" """Scan for public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport)) run_async(run_scan(filter_duplicates, sync_timeout, transport))
@auracast.command('assist')
@click.option(
'--broadcast-name',
metavar='BROADCAST_NAME',
help='Broadcast Name to tune to',
)
@click.option(
'--source-id',
metavar='SOURCE_ID',
type=int,
help='Source ID (for remove-source command)',
)
@click.option(
'--command',
type=click.Choice(
['monitor-state', 'add-source', 'modify-source', 'remove-source']
),
required=True,
)
@click.argument('transport')
@click.argument('address')
@click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address):
"""Scan for broadcasts on behalf of a audio server"""
run_async(run_assist(broadcast_name, source_id, command, transport, address))
@auracast.command('pair')
@click.argument('transport')
@click.argument('address')
@click.pass_context
def pair(ctx, transport, address):
"""Pair with an audio server"""
run_async(run_pair(transport, address))
def main(): def main():

View File

@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.gatt_client import CharacteristicProxy from bumble.gatt_client import CharacteristicProxy
from bumble.hci import ( from bumble.hci import (
Address,
HCI_Constant, HCI_Constant,
HCI_LE_1M_PHY, HCI_LE_1M_PHY,
HCI_LE_2M_PHY, HCI_LE_2M_PHY,
@@ -289,11 +290,7 @@ class ConsoleApp:
device_config, hci_source, hci_sink device_config, hci_source, hci_sink
) )
else: else:
random_address = ( random_address = Address.generate_static_address()
f"{random.randint(192,255):02X}" # address is static random
)
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
self.append_to_log(f"Setting random address: {random_address}") self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci( self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink 'Bumble', random_address, hci_source, hci_sink
@@ -503,21 +500,9 @@ class ConsoleApp:
self.show_error('not connected') self.show_error('not connected')
return return
# Discover all services, characteristics and descriptors self.append_to_output('Service Discovery starting...')
self.append_to_output('discovering services...') await self.connected_peer.discover_all()
await self.connected_peer.discover_services() self.append_to_output('Service Discovery done!')
self.append_to_output(
f'found {len(self.connected_peer.services)} services,'
' discovering characteristics...'
)
await self.connected_peer.discover_characteristics()
self.append_to_output('found characteristics, discovering descriptors...')
for service in self.connected_peer.services:
for characteristic in service.characteristics:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.show_remote_services(self.connected_peer.services)
async def discover_attributes(self): async def discover_attributes(self):
if not self.connected_peer: if not self.connected_peer:

230
apps/device_info.py Normal file
View File

@@ -0,0 +1,230 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from typing import Callable, Iterable, Optional
import click
from bumble.core import ProtocolError
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def try_show(function: Callable, *args, **kwargs) -> None:
try:
await function(*args, **kwargs)
except ProtocolError as error:
print(color('ERROR:', 'red'), error)
# -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
# -----------------------------------------------------------------------------
async def show_gap_information(
gap_service: GenericAccessServiceProxy,
):
print(color('### Generic Access Profile', 'yellow'))
if gap_service.device_name:
print(
color(' Device Name:', 'green'),
await gap_service.device_name.read_value(),
)
if gap_service.appearance:
print(
color(' Appearance: ', 'green'),
await gap_service.appearance.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_information(
device_information_service: DeviceInformationServiceProxy,
):
print(color('### Device Information', 'yellow'))
if device_information_service.manufacturer_name:
print(
color(' Manufacturer Name:', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number:
print(
color(' Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number:
print(
color(' Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.firmware_revision:
print(
color(' Firmware Revision:', 'green'),
await device_information_service.firmware_revision.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_battery_level(
battery_service: BatteryServiceProxy,
):
print(color('### Battery Information', 'yellow'))
if battery_service.battery_level:
print(
color(' Battery Level:', 'green'),
await battery_service.battery_level.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_tmas(
tmas: TelephonyAndMediaAudioServiceProxy,
):
print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role:
print(
color(' Role:', 'green'),
await tmas.role.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
await peer.discover_services()
for service in peer.services:
await service.discover_characteristics()
print(color('=== Services ===', 'yellow'))
show_services(peer.services)
print()
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
await try_show(show_gap_information, gap_service)
if device_information_service := peer.create_service_proxy(
DeviceInformationServiceProxy
):
await try_show(show_device_information, device_information_service)
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
await try_show(show_battery_level, battery_service)
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas)
if done is not None:
done.set_result(None)
except asyncio.CancelledError:
print(color('!!! Operation canceled', 'red'))
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await show_device_info(Peer(connection), None)
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on(
'connection',
lambda connection: asyncio.create_task(
show_device_info(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
await done
# -----------------------------------------------------------------------------
@click.command()
@click.option('--device-config', help='Device configuration', type=click.Path())
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
@click.argument('transport')
@click.argument('address-or-name', required=False)
def main(device_config, encrypt, transport, address_or_name):
"""
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -75,11 +75,15 @@ async def async_main(device_config, encrypt, transport, address_or_name):
if address_or_name: if address_or_name:
# Connect to the target peer # Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name) connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required # Encrypt the connection if required
if encrypt: if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt() await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await dump_gatt_db(Peer(connection), None) await dump_gatt_db(Peer(connection), None)
else: else:

View File

@@ -33,7 +33,6 @@ import ctypes
import wasmtime import wasmtime
import wasmtime.loader import wasmtime.loader
import liblc3 # type: ignore import liblc3 # type: ignore
import logging
import click import click
import aiohttp.web import aiohttp.web
@@ -43,7 +42,7 @@ from bumble.core import AdvertisingData
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles import bap from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -57,8 +56,8 @@ logger = logging.getLogger(__name__)
DEFAULT_UI_PORT = 7654 DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> bap.PacRecord: def _sink_pac_record() -> pacs.PacRecord:
return bap.PacRecord( return pacs.PacRecord(
coding_format=CodingFormat(CodecID.LC3), coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities( codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=( supported_sampling_frequencies=(
@@ -79,8 +78,8 @@ def _sink_pac_record() -> bap.PacRecord:
) )
def _source_pac_record() -> bap.PacRecord: def _source_pac_record() -> pacs.PacRecord:
return bap.PacRecord( return pacs.PacRecord(
coding_format=CodingFormat(CodecID.LC3), coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities( codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=( supported_sampling_frequencies=(
@@ -447,7 +446,7 @@ class Speaker:
) )
self.device.add_service( self.device.add_service(
bap.PublishedAudioCapabilitiesService( pacs.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF), supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF), available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types supported_sink_context=bap.ContextType(0xFFFF), # All context types
@@ -461,10 +460,10 @@ class Speaker:
) )
) )
ascs = bap.AudioStreamControlService( ascs_service = ascs.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2] self.device, sink_ase_id=[1], source_ase_id=[2]
) )
self.device.add_service(ascs) self.device.add_service(ascs_service)
advertising_data = bytes( advertising_data = bytes(
AdvertisingData( AdvertisingData(
@@ -479,13 +478,13 @@ class Speaker:
), ),
( (
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(bap.PublishedAudioCapabilitiesService.UUID), bytes(pacs.PublishedAudioCapabilitiesService.UUID),
), ),
] ]
) )
) + bytes(bap.UnicastServerAdvertisingData()) ) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine): def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
codec_config = ase.codec_specific_configuration codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration) assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode( pcm = decode(
@@ -495,12 +494,12 @@ class Speaker:
) )
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: bap.AseStateMachine) -> None: def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING: if ase.state == ascs.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration) assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link assert ase.cis_link
if ase.role == bap.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
ase.cis_link.abort_on( ase.cis_link.abort_on(
'disconnection', 'disconnection',
lc3_source_task( lc3_source_task(
@@ -516,10 +515,10 @@ class Speaker:
) )
else: else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase) ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED: elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration) assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == bap.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
setup_encoders( setup_encoders(
codec_config.sampling_frequency.hz, codec_config.sampling_frequency.hz,
codec_config.frame_duration.us, codec_config.frame_duration.us,
@@ -532,7 +531,7 @@ class Speaker:
codec_config.audio_channel_allocation.channel_count, codec_config.audio_channel_allocation.channel_count,
) )
for ase in ascs.ase_state_machines.values(): for ase in ascs_service.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase)) ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on() await self.device.power_on()

View File

@@ -14,13 +14,19 @@
from typing import List, Union from typing import List, Union
from bumble import core
class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> List[bytes]: def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens. """Split input parameters into tokens.
Removes space characters outside of double quote blocks: Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0) T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants" are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = [] tokens = []
in_quotes = False in_quotes = False
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
token = bytearray() token = bytearray()
elif char == b'(': elif char == b'(':
if len(token) > 0: if len(token) > 0:
raise ValueError("open_paren following regular character") raise AtParsingError("open_paren following regular character")
tokens.append(char) tokens.append(char)
elif char == b'"': elif char == b'"':
if len(token) > 0: if len(token) > 0:
raise ValueError("quote following regular character") raise AtParsingError("quote following regular character")
in_quotes = True in_quotes = True
token.extend(char) token.extend(char)
else: else:
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]: def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators. """Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer) tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]] accumulator: List[list] = [[]]
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator.append([]) accumulator.append([])
elif token == b')': elif token == b')':
if len(accumulator) < 2: if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren") raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current) accumulator[-1].append(current)
current = accumulator.pop() current = accumulator.pop()
else: else:
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator[-1].append(current) accumulator[-1].append(current)
if len(accumulator) > 1: if len(accumulator) > 1:
raise ValueError("missing close_paren") raise AtParsingError("missing close_paren")
return accumulator[0] return accumulator[0]

View File

@@ -20,6 +20,7 @@ import enum
import struct import struct
from typing import Dict, Type, Union, Tuple from typing import Dict, Type, Union, Tuple
from bumble import core
from bumble.utils import OpenIntEnum from bumble.utils import OpenIntEnum
@@ -88,7 +89,9 @@ class Frame:
short_name = subclass.__name__.replace("ResponseFrame", "") short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame category_class = ResponseFrame
else: else:
raise ValueError(f"invalid subclass name {subclass.__name__}") raise core.InvalidArgumentError(
f"invalid subclass name {subclass.__name__}"
)
uppercase_indexes = [ uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper() i for i in range(len(short_name)) if short_name[i].isupper()
@@ -106,7 +109,7 @@ class Frame:
@staticmethod @staticmethod
def from_bytes(data: bytes) -> Frame: def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0: if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s") raise core.InvalidPacketError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3) subunit_type = Frame.SubunitType(data[1] >> 3)
@@ -122,7 +125,7 @@ class Frame:
# Extended to the next byte # Extended to the next byte
extension = data[2] extension = data[2]
if extension == 0: if extension == 0:
raise ValueError("extended subunit ID value reserved") raise core.InvalidPacketError("extended subunit ID value reserved")
if extension == 0xFF: if extension == 0xFF:
subunit_id = 5 + 254 + data[3] subunit_id = 5 + 254 + data[3]
opcode_offset = 4 opcode_offset = 4
@@ -131,7 +134,7 @@ class Frame:
opcode_offset = 3 opcode_offset = 3
elif subunit_id == 6: elif subunit_id == 6:
raise ValueError("reserved subunit ID") raise core.InvalidPacketError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset]) opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :] operands = data[opcode_offset + 1 :]
@@ -448,7 +451,7 @@ class PassThroughFrame:
operation_data: bytes, operation_data: bytes,
) -> None: ) -> None:
if len(operation_data) > 255: if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes") raise core.InvalidArgumentError("operation data must be <= 255 bytes")
self.state_flag = state_flag self.state_flag = state_flag
self.operation_id = operation_id self.operation_id = operation_id
self.operation_data = operation_data self.operation_data = operation_data

View File

@@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional
from bumble.colors import color from bumble.colors import color
from bumble import avc from bumble import avc
from bumble import core
from bumble import l2cap from bumble import l2cap
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -275,7 +276,7 @@ class Protocol:
self, pid: int, handler: Protocol.CommandHandler self, pid: int, handler: Protocol.CommandHandler
) -> None: ) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler: if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered") raise core.InvalidArgumentError("command handler not registered")
del self.command_handlers[pid] del self.command_handlers[pid]
def register_response_handler( def register_response_handler(
@@ -287,5 +288,5 @@ class Protocol:
self, pid: int, handler: Protocol.ResponseHandler self, pid: int, handler: Protocol.ResponseHandler
) -> None: ) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler: if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered") raise core.InvalidArgumentError("response handler not registered")
del self.response_handlers[pid] del self.response_handlers[pid]

View File

@@ -43,6 +43,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError, InvalidStateError,
ProtocolError, ProtocolError,
InvalidArgumentError,
name_or_number, name_or_number,
) )
from .a2dp import ( from .a2dp import (
@@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
signal_identifier_str = name[:-7] signal_identifier_str = name[:-7]
message_type = Message.MessageType.RESPONSE_REJECT message_type = Message.MessageType.RESPONSE_REJECT
else: else:
raise ValueError('invalid class name') raise InvalidArgumentError('invalid class name')
subclass.message_type = message_type subclass.message_type = message_type
@@ -2162,6 +2163,9 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
def on_abort_command(self): def on_abort_command(self):
self.emit('abort') self.emit('abort')
def on_delayreport_command(self, delay: int):
self.emit('delay_report', delay)
def on_rtp_channel_open(self): def on_rtp_channel_open(self):
self.emit('rtp_channel_open') self.emit('rtp_channel_open')

View File

@@ -55,6 +55,7 @@ from bumble.sdp import (
) )
from bumble.utils import AsyncRunner, OpenIntEnum from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.core import ( from bumble.core import (
InvalidArgumentError,
ProtocolError, ProtocolError,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_AVCTP_PROTOCOL_ID, BT_AVCTP_PROTOCOL_ID,
@@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter):
def notify_track_changed(self, identifier: bytes) -> None: def notify_track_changed(self, identifier: bytes) -> None:
"""Notify the connected peer of a Track change.""" """Notify the connected peer of a Track change."""
if len(identifier) != 8: if len(identifier) != 8:
raise ValueError("identifier must be 8 bytes") raise InvalidArgumentError("identifier must be 8 bytes")
self.notify_event(TrackChangedEvent(identifier)) self.notify_event(TrackChangedEvent(identifier))
def notify_playback_position_changed(self, position: int) -> None: def notify_playback_position_changed(self, position: int) -> None:

View File

@@ -18,6 +18,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from bumble import core
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BitReader: class BitReader:
@@ -40,7 +42,7 @@ class BitReader:
""" "Read up to 32 bits.""" """ "Read up to 32 bits."""
if bits > 32: if bits > 32:
raise ValueError('maximum read size is 32') raise core.InvalidArgumentError('maximum read size is 32')
if self.bits_cached >= bits: if self.bits_cached >= bits:
# We have enough bits. # We have enough bits.
@@ -53,7 +55,7 @@ class BitReader:
feed_size = len(feed_bytes) feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big') feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits: if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data') raise core.InvalidArgumentError('trying to read past the data')
self.byte_position += feed_size self.byte_position += feed_size
# Combine the new cache and the old cache # Combine the new cache and the old cache
@@ -68,7 +70,7 @@ class BitReader:
def read_bytes(self, count: int): def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data): if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data') raise core.InvalidArgumentError('not enough data')
if self.bit_position % 8: if self.bit_position % 8:
# Not byte aligned # Not byte aligned
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
@staticmethod @staticmethod
def program_config_element(reader: BitReader): def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported') raise core.InvalidPacketError('program_config_element not supported')
@dataclass @dataclass
class GASpecificConfig: class GASpecificConfig:
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
aac_spectral_data_resilience_flags = reader.read(1) aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1) extension_flag_3 = reader.read(1)
if extension_flag_3 == 1: if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported') raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod @staticmethod
def audio_object_type(reader: BitReader): def audio_object_type(reader: BitReader):
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
reader, self.channel_configuration, self.audio_object_type reader, self.channel_configuration, self.audio_object_type
) )
else: else:
raise ValueError( raise core.InvalidPacketError(
f'audioObjectType {self.audio_object_type} not supported' f'audioObjectType {self.audio_object_type} not supported'
) )
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
else: else:
audio_mux_version_a = 0 audio_mux_version_a = 0
if audio_mux_version_a != 0: if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported') raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1: if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader) tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0 stream_cnt = 0
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
num_sub_frames = reader.read(6) num_sub_frames = reader.read(6)
num_program = reader.read(4) num_program = reader.read(4)
if num_program != 0: if num_program != 0:
raise ValueError('num_program != 0 not supported') raise core.InvalidPacketError('num_program != 0 not supported')
num_layer = reader.read(3) num_layer = reader.read(3)
if num_layer != 0: if num_layer != 0:
raise ValueError('num_layer != 0 not supported') raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0: if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig( self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader reader
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
) )
audio_specific_config_len = reader.bit_position - marker audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len: if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len') raise core.InvalidPacketError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len asc_len -= audio_specific_config_len
reader.skip(asc_len) reader.skip(asc_len)
frame_length_type = reader.read(3) frame_length_type = reader.read(3)
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
elif frame_length_type == 1: elif frame_length_type == 1:
frame_length = reader.read(9) frame_length = reader.read(9)
else: else:
raise ValueError(f'frame_length_type {frame_length_type} not supported') raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1) self.other_data_present = reader.read(1)
if self.other_data_present: if self.other_data_present:
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
def __init__(self, reader: BitReader, mux_config_present: int): def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0: if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported') raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41 # AudioMuxElement - ISO/EIC 14496-3 Table 1.41
use_same_stream_mux = reader.read(1) use_same_stream_mux = reader.read(1)
if use_same_stream_mux: if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported') raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader) self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support: # We only support:

View File

@@ -16,6 +16,10 @@ from functools import partial
from typing import List, Optional, Union from typing import List, Optional, Union
class ColorError(ValueError):
"""Error raised when a color spec is invalid."""
# ANSI color names. There is also a "default" # ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255: elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec) return _join(base + 8, 5, spec)
else: else:
raise ValueError('Invalid color spec "%s"' % spec) raise ColorError('Invalid color spec "%s"' % spec)
def color( def color(
@@ -72,7 +76,7 @@ def color(
if style_part in STYLES: if style_part in STYLES:
codes.append(STYLES.index(style_part)) codes.append(STYLES.index(style_part))
else: else:
raise ValueError('Invalid style "%s"' % style_part) raise ColorError('Invalid style "%s"' % style_part)
if codes: if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s) return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)

View File

@@ -79,7 +79,13 @@ def get_dict_key_by_value(dictionary, value):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BaseError(Exception):
class BaseBumbleError(Exception):
"""Base Error raised by Bumble."""
class BaseError(BaseBumbleError):
"""Base class for errors with an error code, error name and namespace""" """Base class for errors with an error code, error name and namespace"""
def __init__( def __init__(
@@ -118,18 +124,42 @@ class ProtocolError(BaseError):
"""Protocol Error""" """Protocol Error"""
class TimeoutError(Exception): # pylint: disable=redefined-builtin class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin
"""Timeout Error""" """Timeout Error"""
class CommandTimeoutError(Exception): class CommandTimeoutError(BaseBumbleError):
"""Command Timeout Error""" """Command Timeout Error"""
class InvalidStateError(Exception): class InvalidStateError(BaseBumbleError):
"""Invalid State Error""" """Invalid State Error"""
class InvalidArgumentError(BaseBumbleError, ValueError):
"""Invalid Argument Error"""
class InvalidPacketError(BaseBumbleError, ValueError):
"""Invalid Packet Error"""
class InvalidOperationError(BaseBumbleError, RuntimeError):
"""Invalid Operation Error"""
class NotSupportedError(BaseBumbleError, RuntimeError):
"""Not Supported"""
class OutOfResourcesError(BaseBumbleError, RuntimeError):
"""Out of Resources Error"""
class UnreachableError(BaseBumbleError):
"""The code path raising this error should be unreachable."""
class ConnectionError(BaseError): # pylint: disable=redefined-builtin class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error""" """Connection Error"""
@@ -188,12 +218,12 @@ class UUID:
or uuid_str_or_int[18] != '-' or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-' or uuid_str_or_int[23] != '-'
): ):
raise ValueError('invalid UUID format') raise InvalidArgumentError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '') uuid_str = uuid_str_or_int.replace('-', '')
else: else:
uuid_str = uuid_str_or_int uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4: if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError(f"invalid UUID format: {uuid_str}") raise InvalidArgumentError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name self.name = name
@@ -218,7 +248,7 @@ class UUID:
return self.register() return self.register()
raise ValueError('only 2, 4 and 16 bytes are allowed') raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
@classmethod @classmethod
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID: def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:

View File

@@ -27,6 +27,7 @@ import copy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, IntEnum from enum import Enum, IntEnum
import functools import functools
import itertools
import json import json
import logging import logging
import secrets import secrets
@@ -50,6 +51,7 @@ from typing_extensions import Self
from pyee import EventEmitter from pyee import EventEmitter
from bumble import hci
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from .gatt import Characteristic, Descriptor, Service from .gatt import Characteristic, Descriptor, Service
@@ -111,6 +113,7 @@ from .hci import (
HCI_LE_Periodic_Advertising_Create_Sync_Command, HCI_LE_Periodic_Advertising_Create_Sync_Command,
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command, HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command,
HCI_LE_Periodic_Advertising_Report_Event, HCI_LE_Periodic_Advertising_Report_Event,
HCI_LE_Periodic_Advertising_Sync_Transfer_Command,
HCI_LE_Periodic_Advertising_Terminate_Sync_Command, HCI_LE_Periodic_Advertising_Terminate_Sync_Command,
HCI_LE_Enable_Encryption_Command, HCI_LE_Enable_Encryption_Command,
HCI_LE_Extended_Advertising_Report_Event, HCI_LE_Extended_Advertising_Report_Event,
@@ -167,21 +170,29 @@ from .hci import (
OwnAddressType, OwnAddressType,
LeFeature, LeFeature,
LeFeatureMask, LeFeatureMask,
LmpFeatureMask,
Phy, Phy,
phy_list_to_bits, phy_list_to_bits,
) )
from .host import Host from .host import Host
from .gap import GenericAccessService from .profiles.gap import GenericAccessService
from .core import ( from .core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
AdvertisingData, AdvertisingData,
BaseBumbleError,
ConnectionParameterUpdateError, ConnectionParameterUpdateError,
CommandTimeoutError, CommandTimeoutError,
ConnectionParameters,
ConnectionPHY, ConnectionPHY,
InvalidArgumentError,
InvalidOperationError,
InvalidStateError, InvalidStateError,
NotSupportedError,
OutOfResourcesError,
UnreachableError,
) )
from .utils import ( from .utils import (
AsyncRunner, AsyncRunner,
@@ -196,13 +207,13 @@ from .keys import (
KeyStore, KeyStore,
PairingKeys, PairingKeys,
) )
from .pairing import PairingConfig from bumble import pairing
from . import gatt_client from bumble import gatt_client
from . import gatt_server from bumble import gatt_server
from . import smp from bumble import smp
from . import sdp from bumble import sdp
from . import l2cap from bumble import l2cap
from . import core from bumble import core
if TYPE_CHECKING: if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink from .transport.common import TransportSource, TransportSink
@@ -253,8 +264,9 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
DEVICE_DEFAULT_ADVERTISING_TX_POWER = ( DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
) )
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0 DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0 DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -266,6 +278,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ObjectLookupError(BaseBumbleError):
"""Error raised when failed to lookup an object."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -958,20 +972,25 @@ class PeriodicAdvertisingSync(EventEmitter):
response = await self.device.send_command( response = await self.device.send_command(
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(), HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(),
) )
if response.status == HCI_SUCCESS: if response.return_parameters == HCI_SUCCESS:
if self in self.device.periodic_advertising_syncs: if self in self.device.periodic_advertising_syncs:
self.device.periodic_advertising_syncs.remove(self) self.device.periodic_advertising_syncs.remove(self)
return return
if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST): if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST):
self.state = self.State.TERMINATED self.state = self.State.TERMINATED
await self.device.send_command( if self.sync_handle is not None:
HCI_LE_Periodic_Advertising_Terminate_Sync_Command( await self.device.send_command(
sync_handle=self.sync_handle HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
sync_handle=self.sync_handle
)
) )
)
self.device.periodic_advertising_syncs.remove(self) self.device.periodic_advertising_syncs.remove(self)
async def transfer(self, connection: Connection, service_data: int = 0) -> None:
if self.sync_handle is not None:
await connection.transfer_periodic_sync(self.sync_handle, service_data)
def on_establishment( def on_establishment(
self, self,
status, status,
@@ -1133,6 +1152,15 @@ class Peer:
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]: async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes() return await self.gatt_client.discover_attributes()
async def discover_all(self):
await self.discover_services()
for service in self.services:
await self.discover_characteristics(service=service)
for service in self.services:
for characteristic in service.characteristics:
await self.discover_descriptors(characteristic=characteristic)
async def subscribe( async def subscribe(
self, self,
characteristic: gatt_client.CharacteristicProxy, characteristic: gatt_client.CharacteristicProxy,
@@ -1172,12 +1200,29 @@ class Peer:
return self.gatt_client.get_services_by_uuid(uuid) return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid( def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None self,
uuid: core.UUID,
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
) -> List[gatt_client.CharacteristicProxy]: ) -> List[gatt_client.CharacteristicProxy]:
if isinstance(service, core.UUID):
return list(
itertools.chain(
*[
self.get_characteristics_by_uuid(uuid, s)
for s in self.get_services_by_uuid(service)
]
)
)
return self.gatt_client.get_characteristics_by_uuid(uuid, service) return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS: def create_service_proxy(
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client)) self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
if proxy := proxy_class.from_client(self.gatt_client):
return cast(_PROXY_CLASS, proxy)
return None
async def discover_service_and_create_proxy( async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS] self, proxy_class: Type[_PROXY_CLASS]
@@ -1274,6 +1319,7 @@ class Connection(CompositeEventEmitter):
handle: int handle: int
transport: int transport: int
self_address: Address self_address: Address
self_resolvable_address: Optional[Address]
peer_address: Address peer_address: Address
peer_resolvable_address: Optional[Address] peer_resolvable_address: Optional[Address]
peer_le_features: Optional[LeFeatureMask] peer_le_features: Optional[LeFeatureMask]
@@ -1321,6 +1367,7 @@ class Connection(CompositeEventEmitter):
handle, handle,
transport, transport,
self_address, self_address,
self_resolvable_address,
peer_address, peer_address,
peer_resolvable_address, peer_resolvable_address,
role, role,
@@ -1332,6 +1379,7 @@ class Connection(CompositeEventEmitter):
self.handle = handle self.handle = handle
self.transport = transport self.transport = transport
self.self_address = self_address self.self_address = self_address
self.self_resolvable_address = self_resolvable_address
self.peer_address = peer_address self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only self.peer_name = None # Classic only
@@ -1365,6 +1413,7 @@ class Connection(CompositeEventEmitter):
None, None,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
device.public_address, device.public_address,
None,
peer_address, peer_address,
None, None,
role, role,
@@ -1458,11 +1507,9 @@ class Connection(CompositeEventEmitter):
try: try:
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout) await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
except asyncio.TimeoutError: finally:
pass self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None: async def set_data_length(self, tx_octets, tx_time) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time) return await self.device.set_data_length(self, tx_octets, tx_time)
@@ -1493,6 +1540,11 @@ class Connection(CompositeEventEmitter):
async def get_phy(self): async def get_phy(self):
return await self.device.get_connection_phy(self) return await self.device.get_connection_phy(self)
async def transfer_periodic_sync(
self, sync_handle: int, service_data: int = 0
) -> None:
await self.device.transfer_periodic_sync(self, sync_handle, service_data)
# [Classic only] # [Classic only]
async def request_remote_name(self): async def request_remote_name(self):
return await self.device.request_remote_name(self) return await self.device.request_remote_name(self)
@@ -1523,7 +1575,9 @@ class Connection(CompositeEventEmitter):
f'Connection(handle=0x{self.handle:04X}, ' f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, ' f'role={self.role_name}, '
f'self_address={self.self_address}, ' f'self_address={self.self_address}, '
f'peer_address={self.peer_address})' f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
) )
@@ -1538,13 +1592,15 @@ class DeviceConfiguration:
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
le_enabled: bool = True le_enabled: bool = True
# LE host enable 2nd parameter
le_simultaneous_enabled: bool = False le_simultaneous_enabled: bool = False
le_privacy_enabled: bool = False
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
classic_enabled: bool = False classic_enabled: bool = False
classic_sc_enabled: bool = True classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True classic_ssp_enabled: bool = True
classic_smp_enabled: bool = True classic_smp_enabled: bool = True
classic_accept_any: bool = True classic_accept_any: bool = True
classic_interlaced_scan_enabled: bool = True
connectable: bool = True connectable: bool = True
discoverable: bool = True discoverable: bool = True
advertising_data: bytes = bytes( advertising_data: bytes = bytes(
@@ -1555,7 +1611,10 @@ class DeviceConfiguration:
irk: bytes = bytes(16) # This really must be changed for any level of security irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None keystore: Optional[str] = None
address_resolution_offload: bool = False address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = [] self.gatt_services: List[Dict[str, Any]] = []
@@ -1640,7 +1699,9 @@ def with_connection_from_handle(function):
@functools.wraps(function) @functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs): def wrapper(self, connection_handle, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None: if (connection := self.lookup_connection(connection_handle)) is None:
raise ValueError(f'no connection for handle: 0x{connection_handle:04x}') raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
)
return function(self, connection, *args, **kwargs) return function(self, connection, *args, **kwargs)
return wrapper return wrapper
@@ -1655,7 +1716,7 @@ def with_connection_from_address(function):
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address == address: if connection.peer_address == address:
return function(self, connection, *args, **kwargs) return function(self, connection, *args, **kwargs)
raise ValueError('no connection for address') raise ObjectLookupError('no connection for address')
return wrapper return wrapper
@@ -1705,8 +1766,9 @@ device_host_event_handlers: List[str] = []
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Device(CompositeEventEmitter): class Device(CompositeEventEmitter):
# Incomplete list of fields. # Incomplete list of fields.
random_address: Address random_address: Address # Random address that may change with RPA
public_address: Address public_address: Address # Public address (obtained from the controller)
static_address: Address # Random address that can be set but does not change
classic_enabled: bool classic_enabled: bool
name: str name: str
class_of_device: int class_of_device: int
@@ -1836,23 +1898,29 @@ class Device(CompositeEventEmitter):
config = config or DeviceConfiguration() config = config or DeviceConfiguration()
self.config = config self.config = config
self.public_address = Address('00:00:00:00:00:00')
self.name = config.name self.name = config.name
self.public_address = Address.ANY
self.random_address = config.address self.random_address = config.address
self.static_address = config.address
self.class_of_device = config.class_of_device self.class_of_device = config.class_of_device
self.keystore = None self.keystore = None
self.irk = config.irk self.irk = config.irk
self.le_enabled = config.le_enabled self.le_enabled = config.le_enabled
self.classic_enabled = config.classic_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
self.classic_enabled = config.classic_enabled
self.cis_enabled = config.cis_enabled self.cis_enabled = config.cis_enabled
self.classic_sc_enabled = config.classic_sc_enabled self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_smp_enabled = config.classic_smp_enabled self.classic_smp_enabled = config.classic_smp_enabled
self.classic_interlaced_scan_enabled = config.classic_interlaced_scan_enabled
self.discoverable = config.discoverable self.discoverable = config.discoverable
self.connectable = config.connectable self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload self.address_resolution_offload = config.address_resolution_offload
self.address_generation_offload = config.address_generation_offload
# Extended advertising. # Extended advertising.
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {} self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
@@ -1908,10 +1976,23 @@ class Device(CompositeEventEmitter):
if isinstance(address, str): if isinstance(address, str):
address = Address(address) address = Address(address)
self.random_address = address self.random_address = address
self.static_address = address
# Setup SMP # Setup SMP
self.smp_manager = smp.Manager( self.smp_manager = smp.Manager(
self, pairing_config_factory=lambda connection: PairingConfig() self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
) )
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu) self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2093,7 +2174,7 @@ class Device(CompositeEventEmitter):
spec=spec, spec=spec,
) )
else: else:
raise ValueError(f'Unexpected mode {spec}') raise InvalidArgumentError(f'Unexpected mode {spec}')
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu) self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -2135,26 +2216,26 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command( HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled), le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled), simultaneous_le_host=int(self.le_simultaneous_enabled),
) ),
check_result=True,
) )
if self.le_enabled: if self.le_enabled:
# Set the controller address # Generate a random address if not set.
if self.random_address == Address.ANY_RANDOM: if self.static_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller self.static_address = Address.generate_static_address()
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes # If LE Privacy is enabled, generate an RPA
response = await self.send_command( if self.le_privacy_enabled:
HCI_LE_Rand_Command(), check_result=True self.random_address = Address.generate_private_address(self.irk)
logger.info(f'Initial RPA: {self.random_address}')
if self.le_rpa_timeout > 0:
# Start a task to periodically generate a new RPA
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
) )
else:
# Ensure the address bytes can be a static random address self.random_address = self.static_address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
if self.random_address != Address.ANY_RANDOM: if self.random_address != Address.ANY_RANDOM:
logger.debug( logger.debug(
@@ -2179,7 +2260,8 @@ class Device(CompositeEventEmitter):
await self.send_command( await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command( HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1 address_resolution_enable=1
) ),
check_result=True,
) )
if self.cis_enabled: if self.cis_enabled:
@@ -2187,7 +2269,8 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command( HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM, bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1, bit_value=1,
) ),
check_result=True,
) )
if self.classic_enabled: if self.classic_enabled:
@@ -2210,6 +2293,21 @@ class Device(CompositeEventEmitter):
await self.set_connectable(self.connectable) await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable) await self.set_discoverable(self.discoverable)
if self.classic_interlaced_scan_enabled:
if self.host.supports_lmp_features(LmpFeatureMask.INTERLACED_PAGE_SCAN):
await self.send_command(
hci.HCI_Write_Page_Scan_Type_Command(page_scan_type=1),
check_result=True,
)
if self.host.supports_lmp_features(
LmpFeatureMask.INTERLACED_INQUIRY_SCAN
):
await self.send_command(
hci.HCI_Write_Inquiry_Scan_Type_Command(scan_type=1),
check_result=True,
)
# Done # Done
self.powered_on = True self.powered_on = True
@@ -2218,9 +2316,45 @@ class Device(CompositeEventEmitter):
async def power_off(self) -> None: async def power_off(self) -> None:
if self.powered_on: if self.powered_on:
if self.le_rpa_periodic_update_task:
self.le_rpa_periodic_update_task.cancel()
await self.host.flush() await self.host.flush()
self.powered_on = False self.powered_on = False
async def update_rpa(self) -> bool:
"""
Try to update the RPA.
Returns:
True if the RPA was updated, False if it could not be updated.
"""
# Check if this is a good time to rotate the address
if self.is_advertising or self.is_scanning or self.is_le_connecting:
logger.debug('skipping RPA update')
return False
random_address = Address.generate_private_address(self.irk)
response = await self.send_command(
HCI_LE_Set_Random_Address_Command(random_address=self.random_address)
)
if response.return_parameters == HCI_SUCCESS:
logger.info(f'new RPA: {random_address}')
self.random_address = random_address
return True
else:
logger.warning(f'failed to set RPA: {response.return_parameters}')
return False
async def _run_rpa_periodic_update(self) -> None:
"""Update the RPA periodically"""
while self.le_rpa_timeout != 0:
await asyncio.sleep(self.le_rpa_timeout)
if not self.update_rpa():
logger.debug("periodic RPA update failed")
async def refresh_resolving_list(self) -> None: async def refresh_resolving_list(self) -> None:
assert self.keystore is not None assert self.keystore is not None
@@ -2228,7 +2362,7 @@ class Device(CompositeEventEmitter):
# Create a host-side address resolver # Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys) self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload: if self.address_resolution_offload or self.address_generation_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) await self.send_command(HCI_LE_Clear_Resolving_List_Command())
# Add an empty entry for non-directed address generation. # Add an empty entry for non-directed address generation.
@@ -2254,7 +2388,7 @@ class Device(CompositeEventEmitter):
def supports_le_features(self, feature: LeFeatureMask) -> bool: def supports_le_features(self, feature: LeFeatureMask) -> bool:
return self.host.supports_le_features(feature) return self.host.supports_le_features(feature)
def supports_le_phy(self, phy): def supports_le_phy(self, phy: int) -> bool:
if phy == HCI_LE_1M_PHY: if phy == HCI_LE_1M_PHY:
return True return True
@@ -2263,7 +2397,7 @@ class Device(CompositeEventEmitter):
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY, HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
} }
if phy not in feature_map: if phy not in feature_map:
raise ValueError('invalid PHY') raise InvalidArgumentError('invalid PHY')
return self.supports_le_features(feature_map[phy]) return self.supports_le_features(feature_map[phy])
@@ -2271,6 +2405,10 @@ class Device(CompositeEventEmitter):
def supports_le_extended_advertising(self): def supports_le_extended_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING) return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING)
@property
def supports_le_periodic_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_PERIODIC_ADVERTISING)
async def start_advertising( async def start_advertising(
self, self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
@@ -2323,7 +2461,7 @@ class Device(CompositeEventEmitter):
# Decide what peer address to use # Decide what peer address to use
if advertising_type.is_directed: if advertising_type.is_directed:
if target is None: if target is None:
raise ValueError('directed advertising requires a target') raise InvalidArgumentError('directed advertising requires a target')
peer_address = target peer_address = target
else: else:
peer_address = Address.ANY peer_address = Address.ANY
@@ -2430,7 +2568,7 @@ class Device(CompositeEventEmitter):
and advertising_data and advertising_data
and scan_response_data and scan_response_data
): ):
raise ValueError( raise InvalidArgumentError(
"Extended advertisements can't have both data and scan \ "Extended advertisements can't have both data and scan \
response data" response data"
) )
@@ -2446,7 +2584,9 @@ class Device(CompositeEventEmitter):
if handle not in self.extended_advertising_sets if handle not in self.extended_advertising_sets
) )
except StopIteration as exc: except StopIteration as exc:
raise RuntimeError("all valid advertising handles already in use") from exc raise OutOfResourcesError(
"all valid advertising handles already in use"
) from exc
# Use the device's random address if a random address is needed but none was # Use the device's random address if a random address is needed but none was
# provided. # provided.
@@ -2545,14 +2685,14 @@ class Device(CompositeEventEmitter):
) -> None: ) -> None:
# Check that the arguments are legal # Check that the arguments are legal
if scan_interval < scan_window: if scan_interval < scan_window:
raise ValueError('scan_interval must be >= scan_window') raise InvalidArgumentError('scan_interval must be >= scan_window')
if ( if (
scan_interval < DEVICE_MIN_SCAN_INTERVAL scan_interval < DEVICE_MIN_SCAN_INTERVAL
or scan_interval > DEVICE_MAX_SCAN_INTERVAL or scan_interval > DEVICE_MAX_SCAN_INTERVAL
): ):
raise ValueError('scan_interval out of range') raise InvalidArgumentError('scan_interval out of range')
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW: if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
raise ValueError('scan_interval out of range') raise InvalidArgumentError('scan_interval out of range')
# Reset the accumulators # Reset the accumulators
self.advertisement_accumulators = {} self.advertisement_accumulators = {}
@@ -2580,7 +2720,7 @@ class Device(CompositeEventEmitter):
scanning_phy_count += 1 scanning_phy_count += 1
if scanning_phy_count == 0: if scanning_phy_count == 0:
raise ValueError('at least one scanning PHY must be enabled') raise InvalidArgumentError('at least one scanning PHY must be enabled')
await self.send_command( await self.send_command(
HCI_LE_Set_Extended_Scan_Parameters_Command( HCI_LE_Set_Extended_Scan_Parameters_Command(
@@ -2671,6 +2811,10 @@ class Device(CompositeEventEmitter):
sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT, sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT,
filter_duplicates: bool = False, filter_duplicates: bool = False,
) -> PeriodicAdvertisingSync: ) -> PeriodicAdvertisingSync:
# Check that the controller supports the feature.
if not self.supports_le_periodic_advertising:
raise NotSupportedError()
# Check that there isn't already an equivalent entry # Check that there isn't already an equivalent entry
if any( if any(
sync.advertiser_address == advertiser_address and sync.sid == sid sync.advertiser_address == advertiser_address and sync.sid == sid
@@ -2868,23 +3012,52 @@ class Device(CompositeEventEmitter):
] = None, ] = None,
own_address_type: int = OwnAddressType.RANDOM, own_address_type: int = OwnAddressType.RANDOM,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
always_resolve: bool = False,
) -> Connection: ) -> Connection:
''' '''
Request a connection to a peer. Request a connection to a peer.
When transport is BLE, this method cannot be called if there is already a
When the transport is BLE, this method cannot be called if there is already a
pending connection. pending connection.
connection_parameters_preferences: (BLE only, ignored for BR/EDR) Args:
* None: use the 1M PHY with default parameters peer_address:
* map: each entry has a PHY as key and a ConnectionParametersPreferences Address or name of the device to connect to.
object as value If a string is passed:
If the string is an address followed by a `@` suffix, the `always_resolve`
argument is implicitly set to True, so the connection is made to the
address after resolution.
If the string is any other address, the connection is made to that
address (with or without address resolution, depending on the
`always_resolve` argument).
For any other string, a scan for devices using that string as their name
is initiated, and a connection to the first matching device's address
is made. In that case, `always_resolve` is ignored.
own_address_type: (BLE only) connection_parameters_preferences:
(BLE only, ignored for BR/EDR)
* None: use the 1M PHY with default parameters
* map: each entry has a PHY as key and a ConnectionParametersPreferences
object as value
own_address_type:
(BLE only, ignored for BR/EDR)
OwnAddressType.RANDOM to use this device's random address, or
OwnAddressType.PUBLIC to use this device's public address.
timeout:
Maximum time to wait for a connection to be established, in seconds.
Pass None for an unlimited time.
always_resolve:
(BLE only, ignored for BR/EDR)
If True, always initiate a scan, resolving addresses, and connect to the
address that resolves to `peer_address`.
''' '''
# Check parameters # Check parameters
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT): if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
raise ValueError('invalid transport') raise InvalidArgumentError('invalid transport')
# Adjust the transport automatically if we need to # Adjust the transport automatically if we need to
if transport == BT_LE_TRANSPORT and not self.le_enabled: if transport == BT_LE_TRANSPORT and not self.le_enabled:
@@ -2898,11 +3071,19 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str): if isinstance(peer_address, str):
try: try:
peer_address = Address.from_string_for_transport( if transport == BT_LE_TRANSPORT and peer_address.endswith('@'):
peer_address, transport peer_address = Address.from_string_for_transport(
) peer_address[:-1], transport
except ValueError: )
always_resolve = True
logger.debug('forcing address resolution')
else:
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except (InvalidArgumentError, ValueError):
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
peer_address, transport peer_address, transport
@@ -2913,10 +3094,16 @@ class Device(CompositeEventEmitter):
transport == BT_BR_EDR_TRANSPORT transport == BT_BR_EDR_TRANSPORT
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
): ):
raise ValueError('BR/EDR addresses must be PUBLIC') raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, Address) assert isinstance(peer_address, Address)
if transport == BT_LE_TRANSPORT and always_resolve:
logger.debug('resolving address')
peer_address = await self.find_peer_by_identity_address(
peer_address
) # TODO: timeout
def on_connection(connection): def on_connection(connection):
if transport == BT_LE_TRANSPORT or ( if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address # match BR/EDR connection event against peer address
@@ -2964,7 +3151,7 @@ class Device(CompositeEventEmitter):
) )
) )
if not phys: if not phys:
raise ValueError('at least one supported PHY needed') raise InvalidArgumentError('at least one supported PHY needed')
phy_count = len(phys) phy_count = len(phys)
initiating_phys = phy_list_to_bits(phys) initiating_phys = phy_list_to_bits(phys)
@@ -3036,7 +3223,7 @@ class Device(CompositeEventEmitter):
) )
else: else:
if HCI_LE_1M_PHY not in connection_parameters_preferences: if HCI_LE_1M_PHY not in connection_parameters_preferences:
raise ValueError('1M PHY preferences required') raise InvalidArgumentError('1M PHY preferences required')
prefs = connection_parameters_preferences[HCI_LE_1M_PHY] prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
result = await self.send_command( result = await self.send_command(
@@ -3136,7 +3323,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str): if isinstance(peer_address, str):
try: try:
peer_address = Address(peer_address) peer_address = Address(peer_address)
except ValueError: except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
@@ -3146,7 +3333,7 @@ class Device(CompositeEventEmitter):
assert isinstance(peer_address, Address) assert isinstance(peer_address, Address)
if peer_address == Address.NIL: if peer_address == Address.NIL:
raise ValueError('accept on nil address') raise InvalidArgumentError('accept on nil address')
# Create a future so that we can wait for the request # Create a future so that we can wait for the request
pending_request_fut = asyncio.get_running_loop().create_future() pending_request_fut = asyncio.get_running_loop().create_future()
@@ -3259,7 +3446,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str): if isinstance(peer_address, str):
try: try:
peer_address = Address(peer_address) peer_address = Address(peer_address)
except ValueError: except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
@@ -3302,10 +3489,10 @@ class Device(CompositeEventEmitter):
async def set_data_length(self, connection, tx_octets, tx_time) -> None: async def set_data_length(self, connection, tx_octets, tx_time) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB: if tx_octets < 0x001B or tx_octets > 0x00FB:
raise ValueError('tx_octets must be between 0x001B and 0x00FB') raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
if tx_time < 0x0148 or tx_time > 0x4290: if tx_time < 0x0148 or tx_time > 0x4290:
raise ValueError('tx_time must be between 0x0148 and 0x4290') raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')
return await self.send_command( return await self.send_command(
HCI_LE_Set_Data_Length_Command( HCI_LE_Set_Data_Length_Command(
@@ -3418,15 +3605,26 @@ class Device(CompositeEventEmitter):
check_result=True, check_result=True,
) )
async def transfer_periodic_sync(
self, connection: Connection, sync_handle: int, service_data: int = 0
) -> None:
return await self.send_command(
HCI_LE_Periodic_Advertising_Sync_Transfer_Command(
connection_handle=connection.handle,
service_data=service_data,
sync_handle=sync_handle,
),
check_result=True,
)
async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT): async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT):
""" """
Scan for a peer with a give name and return its address and transport Scan for a peer with a given name and return its address.
""" """
# Create a future to wait for an address to be found # Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future() peer_address = asyncio.get_running_loop().create_future()
# Scan/inquire with event handlers to handle scan/inquiry results
def on_peer_found(address, ad_data): def on_peer_found(address, ad_data):
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
if local_name is None: if local_name is None:
@@ -3435,13 +3633,13 @@ class Device(CompositeEventEmitter):
if local_name.decode('utf-8') == name: if local_name.decode('utf-8') == name:
peer_address.set_result(address) peer_address.set_result(address)
handler = None listener = None
was_scanning = self.scanning was_scanning = self.scanning
was_discovering = self.discovering was_discovering = self.discovering
try: try:
if transport == BT_LE_TRANSPORT: if transport == BT_LE_TRANSPORT:
event_name = 'advertisement' event_name = 'advertisement'
handler = self.on( listener = self.on(
event_name, event_name,
lambda advertisement: on_peer_found( lambda advertisement: on_peer_found(
advertisement.address, advertisement.data advertisement.address, advertisement.data
@@ -3453,7 +3651,7 @@ class Device(CompositeEventEmitter):
elif transport == BT_BR_EDR_TRANSPORT: elif transport == BT_BR_EDR_TRANSPORT:
event_name = 'inquiry_result' event_name = 'inquiry_result'
handler = self.on( listener = self.on(
event_name, event_name,
lambda address, class_of_device, eir_data, rssi: on_peer_found( lambda address, class_of_device, eir_data, rssi: on_peer_found(
address, eir_data address, eir_data
@@ -3467,21 +3665,67 @@ class Device(CompositeEventEmitter):
return await self.abort_on('flush', peer_address) return await self.abort_on('flush', peer_address)
finally: finally:
if handler is not None: if listener is not None:
self.remove_listener(event_name, handler) self.remove_listener(event_name, listener)
if transport == BT_LE_TRANSPORT and not was_scanning: if transport == BT_LE_TRANSPORT and not was_scanning:
await self.stop_scanning() await self.stop_scanning()
elif transport == BT_BR_EDR_TRANSPORT and not was_discovering: elif transport == BT_BR_EDR_TRANSPORT and not was_discovering:
await self.stop_discovery() await self.stop_discovery()
async def find_peer_by_identity_address(self, identity_address: Address) -> Address:
"""
Scan for a peer with a resolvable address that can be resolved to a given
identity address.
"""
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()
def on_peer_found(address, _):
if address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching public address found for {address}')
peer_address.set_result(address)
return
if address.is_resolvable:
resolved_address = self.address_resolver.resolve(address)
if resolved_address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching identity found for {address}')
peer_address.set_result(address)
return
was_scanning = self.scanning
event_name = 'advertisement'
listener = None
try:
listener = self.on(
event_name,
lambda advertisement: on_peer_found(
advertisement.address, advertisement.data
),
)
if not self.scanning:
await self.start_scanning(filter_duplicates=True)
return await self.abort_on('flush', peer_address)
finally:
if listener is not None:
self.remove_listener(event_name, listener)
if not was_scanning:
await self.stop_scanning()
@property @property
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
return self.smp_manager.pairing_config_factory return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter @pairing_config_factory.setter
def pairing_config_factory( def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig] self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
) -> None: ) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory self.smp_manager.pairing_config_factory = pairing_config_factory
@@ -3580,7 +3824,7 @@ class Device(CompositeEventEmitter):
async def encrypt(self, connection, enable=True): async def encrypt(self, connection, enable=True):
if not enable and connection.transport == BT_LE_TRANSPORT: if not enable and connection.transport == BT_LE_TRANSPORT:
raise ValueError('`enable` parameter is classic only.') raise InvalidArgumentError('`enable` parameter is classic only.')
# Set up event handlers # Set up event handlers
pending_encryption = asyncio.get_running_loop().create_future() pending_encryption = asyncio.get_running_loop().create_future()
@@ -3599,11 +3843,12 @@ class Device(CompositeEventEmitter):
if connection.transport == BT_LE_TRANSPORT: if connection.transport == BT_LE_TRANSPORT:
# Look for a key in the key store # Look for a key in the key store
if self.keystore is None: if self.keystore is None:
raise RuntimeError('no key store') raise InvalidOperationError('no key store')
logger.debug(f'Looking up key for {connection.peer_address}')
keys = await self.keystore.get(str(connection.peer_address)) keys = await self.keystore.get(str(connection.peer_address))
if keys is None: if keys is None:
raise RuntimeError('keys not found in key store') raise InvalidOperationError('keys not found in key store')
if keys.ltk is not None: if keys.ltk is not None:
ltk = keys.ltk.value ltk = keys.ltk.value
@@ -3614,7 +3859,7 @@ class Device(CompositeEventEmitter):
rand = keys.ltk_central.rand rand = keys.ltk_central.rand
ediv = keys.ltk_central.ediv ediv = keys.ltk_central.ediv
else: else:
raise RuntimeError('no LTK found for peer') raise InvalidOperationError('no LTK found for peer')
if connection.role != HCI_CENTRAL_ROLE: if connection.role != HCI_CENTRAL_ROLE:
raise InvalidStateError('only centrals can start encryption') raise InvalidStateError('only centrals can start encryption')
@@ -3889,7 +4134,7 @@ class Device(CompositeEventEmitter):
return cis_link return cis_link
# Mypy believes this is reachable when context is an ExitStack. # Mypy believes this is reachable when context is an ExitStack.
raise InvalidStateError('Unreachable') raise UnreachableError()
# [LE only] # [LE only]
@experimental('Only for testing.') @experimental('Only for testing.')
@@ -4036,6 +4281,12 @@ class Device(CompositeEventEmitter):
else self.public_address else self.public_address
) )
if advertising_set.advertising_parameters.own_address_type in (
OwnAddressType.RANDOM,
OwnAddressType.PUBLIC,
):
connection.self_resolvable_address = None
# Setup auto-restart of the advertising set if needed. # Setup auto-restart of the advertising set if needed.
if advertising_set.auto_restart: if advertising_set.auto_restart:
connection.once( connection.once(
@@ -4071,12 +4322,23 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
def on_connection( def on_connection(
self, self,
connection_handle, connection_handle: int,
transport, transport: int,
peer_address, peer_address: Address,
role, self_resolvable_address: Optional[Address],
connection_parameters, peer_resolvable_address: Optional[Address],
): role: int,
connection_parameters: ConnectionParameters,
) -> None:
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if (
peer_resolvable_address == Address.ANY_RANDOM
or not peer_address.is_resolved
):
peer_resolvable_address = None
logger.debug( logger.debug(
f'*** Connection: [0x{connection_handle:04X}] ' f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}' f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
@@ -4097,17 +4359,18 @@ class Device(CompositeEventEmitter):
return return
# Resolve the peer address if we can if peer_resolvable_address is None:
peer_resolvable_address = None # Resolve the peer address if we can
if self.address_resolver: if self.address_resolver:
if peer_address.is_resolvable: if peer_address.is_resolvable:
resolved_address = self.address_resolver.resolve(peer_address) resolved_address = self.address_resolver.resolve(peer_address)
if resolved_address is not None: if resolved_address is not None:
logger.debug(f'*** Address resolved as {resolved_address}') logger.debug(f'*** Address resolved as {resolved_address}')
peer_resolvable_address = peer_address peer_resolvable_address = peer_address
peer_address = resolved_address peer_address = resolved_address
self_address = None self_address = None
own_address_type: Optional[int] = None
if role == HCI_CENTRAL_ROLE: if role == HCI_CENTRAL_ROLE:
own_address_type = self.connect_own_address_type own_address_type = self.connect_own_address_type
assert own_address_type is not None assert own_address_type is not None
@@ -4136,12 +4399,18 @@ class Device(CompositeEventEmitter):
else self.random_address else self.random_address
) )
# Some controllers may return local resolvable address even not using address
# generation offloading. Ignore the value to prevent SMP failure.
if own_address_type in (OwnAddressType.RANDOM, OwnAddressType.PUBLIC):
self_resolvable_address = None
# Create a connection. # Create a connection.
connection = Connection( connection = Connection(
self, self,
connection_handle, connection_handle,
transport, transport,
self_address, self_address,
self_resolvable_address,
peer_address, peer_address,
peer_resolvable_address, peer_resolvable_address,
role, role,
@@ -4152,9 +4421,10 @@ class Device(CompositeEventEmitter):
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser: if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
if self.legacy_advertiser.auto_restart: if self.legacy_advertiser.auto_restart:
advertiser = self.legacy_advertiser
connection.once( connection.once(
'disconnection', 'disconnection',
lambda _: self.abort_on('flush', self.legacy_advertiser.start()), lambda _: self.abort_on('flush', advertiser.start()),
) )
else: else:
self.legacy_advertiser = None self.legacy_advertiser = None
@@ -4377,7 +4647,7 @@ class Device(CompositeEventEmitter):
return await pairing_config.delegate.confirm(auto=True) return await pairing_config.delegate.confirm(auto=True)
async def na() -> bool: async def na() -> bool:
assert False, "N/A: unreachable" raise UnreachableError()
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6 # See Bluetooth spec @ Vol 3, Part C 5.2.2.6
methods = { methods = {
@@ -4838,5 +5108,6 @@ class Device(CompositeEventEmitter):
return ( return (
f'Device(name="{self.name}", ' f'Device(name="{self.name}", '
f'random_address="{self.random_address}", ' f'random_address="{self.random_address}", '
f'public_address="{self.public_address}")' f'public_address="{self.public_address}", '
f'static_address="{self.static_address}")'
) )

View File

@@ -33,6 +33,7 @@ from typing import Tuple
import weakref import weakref
from bumble import core
from bumble.hci import ( from bumble.hci import (
hci_vendor_command_op_code, hci_vendor_command_op_code,
STATUS_SPEC, STATUS_SPEC,
@@ -49,6 +50,10 @@ from bumble.drivers import common
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RtkFirmwareError(core.BaseBumbleError):
"""Error raised when RTK firmware initialization fails."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -208,15 +213,15 @@ class Firmware:
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77]) extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE): if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise ValueError("Firmware does not start with epatch signature") raise RtkFirmwareError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig): if not firmware.endswith(extension_sig):
raise ValueError("Firmware does not end with extension sig") raise RtkFirmwareError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header. # The firmware should start with a 14 byte header.
epatch_header_size = 14 epatch_header_size = 14
if len(firmware) < epatch_header_size: if len(firmware) < epatch_header_size:
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
# Look for the "project ID", starting from the end. # Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig) offset = len(firmware) - len(extension_sig)
@@ -230,7 +235,7 @@ class Firmware:
break break
if length == 0: if length == 0:
raise ValueError("Invalid 0-length instruction") raise RtkFirmwareError("Invalid 0-length instruction")
if opcode == 0 and length == 1: if opcode == 0 and length == 1:
project_id = firmware[offset - 1] project_id = firmware[offset - 1]
@@ -239,7 +244,7 @@ class Firmware:
offset -= length offset -= length
if project_id < 0: if project_id < 0:
raise ValueError("Project ID not found") raise RtkFirmwareError("Project ID not found")
self.project_id = project_id self.project_id = project_id
@@ -252,7 +257,7 @@ class Firmware:
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each) # <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each) # <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware): if epatch_header_size + 8 * num_patches > len(firmware):
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
chip_id_table_offset = epatch_header_size chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
@@ -266,7 +271,7 @@ class Firmware:
"<I", firmware, patch_offset_table_offset + 4 * patch_index "<I", firmware, patch_offset_table_offset + 4 * patch_index
) )
if patch_offset + patch_length > len(firmware): if patch_offset + patch_length > len(firmware):
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
# Get the SVN version for the patch # Get the SVN version for the patch
(svn_version,) = struct.unpack_from( (svn_version,) = struct.unpack_from(
@@ -645,7 +650,7 @@ class Driver(common.Driver):
): ):
return await self.download_for_rtl8723b() return await self.download_for_rtl8723b()
raise ValueError("ROM not supported") raise RtkFirmwareError("ROM not supported")
async def init_controller(self): async def init_controller(self):
await self.download_firmware() await self.download_firmware()

View File

@@ -39,7 +39,7 @@ from typing import (
) )
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID from bumble.core import BaseBumbleError, UUID
from bumble.att import Attribute, AttributeValue from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -320,6 +320,11 @@ def show_services(services: Iterable[Service]) -> None:
print(color(' ' + str(descriptor), 'green')) print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
class InvalidServiceError(BaseBumbleError):
"""The service is not compliant with the spec/profile"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Service(Attribute): class Service(Attribute):
''' '''

View File

@@ -253,7 +253,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: Type[TemplateService] SERVICE_CLASS: Type[TemplateService]
@classmethod @classmethod
def from_client(cls, client: Client) -> ProfileServiceProxy: def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -283,6 +283,8 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
connection.on('disconnection', self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None: def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
@@ -331,9 +333,9 @@ class Client:
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
if mtu < ATT_DEFAULT_MTU: if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF: if mtu > 0xFFFF:
raise ValueError('MTU must be <= 0xFFFF') raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:
@@ -405,7 +407,7 @@ class Client:
if not already_known: if not already_known:
self.services.append(service) self.services.append(service)
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]: async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.1 Discover All Primary Services See Vol 3, Part G - 4.4.1 Discover All Primary Services
''' '''
@@ -1072,6 +1074,10 @@ class Client:
) )
) )
def on_disconnection(self, _) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'

View File

@@ -31,6 +31,8 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
AdvertisingData, AdvertisingData,
DeviceClass, DeviceClass,
InvalidArgumentError,
InvalidPacketError,
ProtocolError, ProtocolError,
bit_flags_to_strings, bit_flags_to_strings,
name_or_number, name_or_number,
@@ -92,14 +94,14 @@ def map_class_of_device(class_of_device):
) )
def phy_list_to_bits(phys): def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int:
if phys is None: if phys is None:
return 0 return 0
phy_bits = 0 phy_bits = 0
for phy in phys: for phy in phys:
if phy not in HCI_LE_PHY_TYPE_TO_BIT: if phy not in HCI_LE_PHY_TYPE_TO_BIT:
raise ValueError('invalid PHY') raise InvalidArgumentError('invalid PHY')
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy] phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
return phy_bits return phy_bits
@@ -1553,7 +1555,7 @@ class HCI_Object:
new_offset, field_value = field_type(data, offset) new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset) return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}') raise InvalidArgumentError(f'unknown field type {field_type}')
@staticmethod @staticmethod
def dict_from_bytes(data, offset, fields): def dict_from_bytes(data, offset, fields):
@@ -1622,7 +1624,7 @@ class HCI_Object:
if 0 <= field_value <= 255: if 0 <= field_value <= 255:
field_bytes = bytes([field_value]) field_bytes = bytes([field_value])
else: else:
raise ValueError('value too large for *-typed field') raise InvalidArgumentError('value too large for *-typed field')
else: else:
field_bytes = bytes(field_value) field_bytes = bytes(field_value)
elif field_type == 'v': elif field_type == 'v':
@@ -1641,7 +1643,9 @@ class HCI_Object:
elif len(field_bytes) > field_type: elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type] field_bytes = field_bytes[:field_type]
else: else:
raise ValueError(f"don't know how to serialize type {type(field_value)}") raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes return field_bytes
@@ -1835,6 +1839,12 @@ class Address:
data, offset, Address.PUBLIC_DEVICE_ADDRESS data, offset, Address.PUBLIC_DEVICE_ADDRESS
) )
@staticmethod
def parse_random_address(data, offset):
return Address.parse_address_with_type(
data, offset, Address.RANDOM_DEVICE_ADDRESS
)
@staticmethod @staticmethod
def parse_address_with_type(data, offset, address_type): def parse_address_with_type(data, offset, address_type):
return offset + 6, Address(data[offset : offset + 6], address_type) return offset + 6, Address(data[offset : offset + 6], address_type)
@@ -1905,7 +1915,7 @@ class Address:
self.address_bytes = bytes(reversed(bytes.fromhex(address))) self.address_bytes = bytes(reversed(bytes.fromhex(address)))
if len(self.address_bytes) != 6: if len(self.address_bytes) != 6:
raise ValueError('invalid address length') raise InvalidArgumentError('invalid address length')
self.address_type = address_type self.address_type = address_type
@@ -1961,7 +1971,8 @@ class Address:
def __eq__(self, other): def __eq__(self, other):
return ( return (
self.address_bytes == other.address_bytes isinstance(other, Address)
and self.address_bytes == other.address_bytes
and self.is_public == other.is_public and self.is_public == other.is_public
) )
@@ -2108,7 +2119,7 @@ class HCI_Command(HCI_Packet):
op_code, length = struct.unpack_from('<HB', packet, 1) op_code, length = struct.unpack_from('<HB', packet, 1)
parameters = packet[4:] parameters = packet[4:]
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
# Look for a registered class # Look for a registered class
cls = HCI_Command.command_classes.get(op_code) cls = HCI_Command.command_classes.get(op_code)
@@ -4518,18 +4529,6 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
''' '''
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command( @HCI_Command.command(
[ [
@@ -4565,6 +4564,32 @@ class HCI_LE_Set_Privacy_Mode_Command(HCI_Command):
return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode) return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode)
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('connection_handle', 2),
],
)
class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command
'''
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command( @HCI_Command.command(
fields=[ fields=[
@@ -4807,7 +4832,7 @@ class HCI_Event(HCI_Packet):
length = packet[2] length = packet[2]
parameters = packet[3:] parameters = packet[3:]
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
cls: Any cls: Any
if event_code == HCI_LE_META_EVENT: if event_code == HCI_LE_META_EVENT:
@@ -5174,8 +5199,8 @@ class HCI_LE_Data_Length_Change_Event(HCI_LE_Meta_Event):
), ),
('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type), ('peer_address', Address.parse_address_preceded_by_type),
('local_resolvable_private_address', Address.parse_address), ('local_resolvable_private_address', Address.parse_random_address),
('peer_resolvable_private_address', Address.parse_address), ('peer_resolvable_private_address', Address.parse_random_address),
('connection_interval', 2), ('connection_interval', 2),
('peripheral_latency', 2), ('peripheral_latency', 2),
('supervision_timeout', 2), ('supervision_timeout', 2),
@@ -6342,7 +6367,7 @@ class HCI_AclDataPacket(HCI_Packet):
bc_flag = (h >> 14) & 3 bc_flag = (h >> 14) & 3
data = packet[5:] data = packet[5:]
if len(data) != data_total_length: if len(data) != data_total_length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
return HCI_AclDataPacket( return HCI_AclDataPacket(
connection_handle, pb_flag, bc_flag, data_total_length, data connection_handle, pb_flag, bc_flag, data_total_length, data
) )
@@ -6390,7 +6415,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
packet_status = (h >> 12) & 0b11 packet_status = (h >> 12) & 0b11
data = packet[4:] data = packet[4:]
if len(data) != data_total_length: if len(data) != data_total_length:
raise ValueError( raise InvalidPacketError(
f'invalid packet length {len(data)} != {data_total_length}' f'invalid packet length {len(data)} != {data_total_length}'
) )
return HCI_SynchronousDataPacket( return HCI_SynchronousDataPacket(

View File

@@ -23,13 +23,12 @@ import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING from typing import Optional, Callable
from typing_extensions import override from typing_extensions import override
from bumble import l2cap, device from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from .hci import Address from bumble.hci import Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
try: try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect( channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM self.connection, HID_CONTROL_PSM
) )
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError: except ProtocolError:
logging.exception(f'L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None: async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel # Create a new L2CAP connection - interrupt channel
try: try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect( channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM self.connection, HID_INTERRUPT_PSM
) )
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError: except ProtocolError:
logging.exception(f'L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None: async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None: if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
@@ -334,17 +329,18 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04 ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF SUCCESS = 0xFF
@dataclass
class GetSetStatus: class GetSetStatus:
def __init__(self) -> None: data: bytes = b''
self.data = bytearray() status: int = 0
self.status = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None: def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE) super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override @override
def on_ctrl_pdu(self, pdu: bytes) -> None: def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -410,7 +406,6 @@ class Device(HID):
buffer_size = 0 buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size) ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE: if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS: elif ret.status == self.GetSetReturn.SUCCESS:
@@ -428,7 +423,9 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None: def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
self.get_report_cb = cb self.get_report_cb = cb
logger.debug("GetReport callback registered successfully") logger.debug("GetReport callback registered successfully")
@@ -442,7 +439,6 @@ class Device(HID):
report_data = pdu[2:] report_data = pdu[2:]
report_size = len(report_data) + 1 report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data) ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -453,7 +449,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb( def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None] self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
) -> None: ) -> None:
self.set_report_cb = cb self.set_report_cb = cb
logger.debug("SetReport callback registered successfully") logger.debug("SetReport callback registered successfully")
@@ -464,13 +460,12 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.get_protocol_cb() ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None: def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
self.get_protocol_cb = cb self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully") logger.debug("GetProtocol callback registered successfully")
@@ -480,13 +475,14 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.set_protocol_cb(pdu[0] & 0x01) ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None: def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
self.set_protocol_cb = cb self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully") logger.debug("SetProtocol callback registered successfully")

View File

@@ -772,6 +772,8 @@ class Host(AbortableEventEmitter):
event.connection_handle, event.connection_handle,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
event.peer_address, event.peer_address,
getattr(event, 'local_resolvable_private_address', None),
getattr(event, 'peer_resolvable_private_address', None),
event.role, event.role,
connection_parameters, connection_parameters,
) )
@@ -817,6 +819,8 @@ class Host(AbortableEventEmitter):
event.bd_addr, event.bd_addr,
None, None,
None, None,
None,
None,
) )
else: else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')

View File

@@ -41,7 +41,14 @@ from typing import (
from .utils import deprecated from .utils import deprecated
from .colors import color from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError from .core import (
BT_CENTRAL_ROLE,
InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
OutOfResourcesError,
ProtocolError,
)
from .hci import ( from .hci import (
HCI_LE_Connection_Update_Command, HCI_LE_Connection_Update_Command,
HCI_Object, HCI_Object,
@@ -189,17 +196,17 @@ class LeCreditBasedChannelSpec:
self.max_credits < 1 self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
): ):
raise ValueError('max credits out of range') raise InvalidArgumentError('max credits out of range')
if ( if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
): ):
raise ValueError('MTU out of range') raise InvalidArgumentError('MTU out of range')
if ( if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
): ):
raise ValueError('MPS out of range') raise InvalidArgumentError('MPS out of range')
class L2CAP_PDU: class L2CAP_PDU:
@@ -211,7 +218,7 @@ class L2CAP_PDU:
def from_bytes(data: bytes) -> L2CAP_PDU: def from_bytes(data: bytes) -> L2CAP_PDU:
# Check parameters # Check parameters
if len(data) < 4: if len(data) < 4:
raise ValueError('not enough data for L2CAP header') raise InvalidPacketError('not enough data for L2CAP header')
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0) _, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
l2cap_pdu_payload = data[4:] l2cap_pdu_payload = data[4:]
@@ -816,7 +823,7 @@ class ClassicChannel(EventEmitter):
# Check that we can start a new connection # Check that we can start a new connection
if self.connection_result: if self.connection_result:
raise RuntimeError('connection already pending') raise InvalidStateError('connection already pending')
self._change_state(self.State.WAIT_CONNECT_RSP) self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame( self.send_control_frame(
@@ -1129,7 +1136,7 @@ class LeCreditBasedChannel(EventEmitter):
# Check that we can start a new connection # Check that we can start a new connection
identifier = self.manager.next_identifier(self.connection) identifier = self.manager.next_identifier(self.connection)
if identifier in self.manager.le_coc_requests: if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests') raise InvalidStateError('too many concurrent connection requests')
self._change_state(self.State.CONNECTING) self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request( request = L2CAP_LE_Credit_Based_Connection_Request(
@@ -1516,7 +1523,7 @@ class ChannelManager:
if cid not in channels: if cid not in channels:
return cid return cid
raise RuntimeError('no free CID available') raise OutOfResourcesError('no free CID available')
@staticmethod @staticmethod
def find_free_le_cid(channels: Iterable[int]) -> int: def find_free_le_cid(channels: Iterable[int]) -> int:
@@ -1529,7 +1536,7 @@ class ChannelManager:
if cid not in channels: if cid not in channels:
return cid return cid
raise RuntimeError('no free CID') raise OutOfResourcesError('no free CID')
def next_identifier(self, connection: Connection) -> int: def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
@@ -1576,15 +1583,15 @@ class ChannelManager:
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if spec.psm in self.servers: if spec.psm in self.servers:
raise ValueError('PSM already in use') raise InvalidArgumentError('PSM already in use')
# Check that the PSM is valid # Check that the PSM is valid
if spec.psm % 2 == 0: if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)') raise InvalidArgumentError('invalid PSM (not odd)')
check = spec.psm >> 8 check = spec.psm >> 8
while check: while check:
if check % 2 != 0: if check % 2 != 0:
raise ValueError('invalid PSM') raise InvalidArgumentError('invalid PSM')
check >>= 8 check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
@@ -1626,7 +1633,7 @@ class ChannelManager:
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers: if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use') raise InvalidArgumentError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self, self,
@@ -2154,10 +2161,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {}) connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels) source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None: if spec.psm is None:
raise ValueError('PSM cannot be None') raise InvalidArgumentError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}') logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
@@ -2206,10 +2213,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {}) connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels) source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None: if spec.psm is None:
raise ValueError('PSM cannot be None') raise InvalidArgumentError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug( logger.debug(

View File

@@ -19,7 +19,12 @@ import logging
import asyncio import asyncio
from functools import partial from functools import partial
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT from bumble.core import (
BT_PERIPHERAL_ROLE,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError,
)
from bumble.colors import color from bumble.colors import color
from bumble.hci import ( from bumble.hci import (
Address, Address,
@@ -405,12 +410,12 @@ class RemoteLink:
def add_controller(self, controller): def add_controller(self, controller):
if self.controller: if self.controller:
raise ValueError('controller already set') raise InvalidStateError('controller already set')
self.controller = controller self.controller = controller
def remove_controller(self, controller): def remove_controller(self, controller):
if self.controller != controller: if self.controller != controller:
raise ValueError('controller mismatch') raise InvalidStateError('controller mismatch')
self.controller = None self.controller = None
def get_pending_connection(self): def get_pending_connection(self):

739
bumble/profiles/ascs.py Normal file
View File

@@ -0,0 +1,739 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Audio Stream Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import logging
import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
# -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# -----------------------------------------------------------------------------
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
# -----------------------------------------------------------------------------
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = (),
sink_ase_id: Sequence[int] = (),
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

View File

@@ -24,15 +24,12 @@ import enum
import struct import struct
import functools import functools
import logging import logging
from typing import Optional, List, Union, Type, Dict, Any, Tuple from typing import List
from typing_extensions import Self from typing_extensions import Self
from bumble import core from bumble import core
from bumble import colors
from bumble import device
from bumble import hci from bumble import hci
from bumble import gatt from bumble import gatt
from bumble import gatt_client
from bumble import utils from bumble import utils
from bumble.profiles import le_audio from bumble.profiles import le_audio
@@ -251,231 +248,6 @@ class AnnouncementType(utils.OpenIntEnum):
TARGETED = 0x01 TARGETED = 0x01
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
@dataclasses.dataclass @dataclasses.dataclass
class UnicastServerAdvertisingData: class UnicastServerAdvertisingData:
"""Advertising Data for ASCS.""" """Advertising Data for ASCS."""
@@ -683,51 +455,6 @@ class CodecSpecificConfiguration:
) )
@dataclasses.dataclass
class PacRecord:
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
# TODO: Parse Metadata
metadata: bytes = b''
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
metadata = data[offset : offset + metadata_size]
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(self.metadata)])
+ self.metadata
)
@dataclasses.dataclass @dataclasses.dataclass
class BroadcastAudioAnnouncement: class BroadcastAudioAnnouncement:
broadcast_id: int broadcast_id: int
@@ -819,603 +546,3 @@ class BasicAudioAnnouncement:
) )
return cls(presentation_delay, subgroups) return cls(presentation_delay, subgroups)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
# TODO: Parse this
metadata = b''
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = [],
sink_ase_id: Sequence[int] = [],
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

440
bumble/profiles/bass.py Normal file
View File

@@ -0,0 +1,440 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Broadcast Audio Scan Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, List, Optional, Sequence
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
from bumble import utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
OPCODE_NOT_SUPPORTED = 0x80
INVALID_SOURCE_ID = 0x81
# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
return bytes([len(subgroups)]) + b"".join(
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
+ subgroup.metadata
for subgroup in subgroups
)
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
num_subgroups = data[0]
offset = 1
subgroups = []
for _ in range(num_subgroups):
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
metadata_length = data[offset + 4]
metadata = data[offset + 5 : offset + 5 + metadata_length]
offset += 5 + metadata_length
subgroups.append(SubgroupInfo(bis_sync, metadata))
return subgroups
# -----------------------------------------------------------------------------
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
@dataclasses.dataclass
class SubgroupInfo:
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
bis_sync: int
metadata: bytes
class ControlPointOperation:
class OpCode(utils.OpenIntEnum):
REMOTE_SCAN_STOPPED = 0x00
REMOTE_SCAN_STARTED = 0x01
ADD_SOURCE = 0x02
MODIFY_SOURCE = 0x03
SET_BROADCAST_CODE = 0x04
REMOVE_SOURCE = 0x05
op_code: OpCode
parameters: bytes
@classmethod
def from_bytes(cls, data: bytes) -> ControlPointOperation:
op_code = data[0]
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
return RemoteScanStoppedOperation()
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
return RemoteScanStartedOperation()
if op_code == cls.OpCode.ADD_SOURCE:
return AddSourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.MODIFY_SOURCE:
return ModifySourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.SET_BROADCAST_CODE:
return SetBroadcastCodeOperation.from_parameters(data[1:])
if op_code == cls.OpCode.REMOVE_SOURCE:
return RemoveSourceOperation.from_parameters(data[1:])
raise core.InvalidArgumentError("invalid op code")
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
self.op_code = op_code
self.parameters = parameters
def __bytes__(self) -> bytes:
return bytes([self.op_code]) + self.parameters
class RemoteScanStoppedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
class RemoteScanStartedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
class AddSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
instance.parameters = parameters
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
parameters, 1
)[1]
instance.advertising_sid = parameters[7]
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
instance.subgroups = decode_subgroups(parameters[14:])
return instance
def __init__(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.ADD_SOURCE,
struct.pack(
"<B6sB3sBH",
advertiser_address.address_type,
bytes(advertiser_address),
advertising_sid,
broadcast_id.to_bytes(3, "little"),
pa_sync,
pa_interval,
)
+ encode_subgroups(subgroups),
)
self.advertiser_address = advertiser_address
self.advertising_sid = advertising_sid
self.broadcast_id = broadcast_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class ModifySourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
instance.subgroups = decode_subgroups(parameters[4:])
return instance
def __init__(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.MODIFY_SOURCE,
struct.pack("<BBH", source_id, pa_sync, pa_interval)
+ encode_subgroups(subgroups),
)
self.source_id = source_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class SetBroadcastCodeOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.broadcast_code = parameters[1:17]
return instance
def __init__(
self,
source_id: int,
broadcast_code: bytes,
) -> None:
super().__init__(
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
bytes([source_id]) + broadcast_code,
)
self.source_id = source_id
self.broadcast_code = broadcast_code
if len(self.broadcast_code) != 16:
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
class RemoveSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
return instance
def __init__(self, source_id: int) -> None:
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
self.source_id = source_id
@dataclasses.dataclass
class BroadcastReceiveState:
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
NOT_SYNCHRONIZED_TO_PA = 0x00
SYNCINFO_REQUEST = 0x01
SYNCHRONIZED_TO_PA = 0x02
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
NO_PAST = 0x04
class BigEncryption(utils.OpenIntEnum):
NOT_ENCRYPTED = 0x00
BROADCAST_CODE_REQUIRED = 0x01
DECRYPTING = 0x02
BAD_CODE = 0x03
source_id: int
source_address: hci.Address
source_adv_sid: int
broadcast_id: int
pa_sync_state: PeriodicAdvertisingSyncState
big_encryption: BigEncryption
bad_code: bytes
subgroups: List[SubgroupInfo]
@classmethod
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
if not data:
return None
source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8]
broadcast_id = int.from_bytes(data[9:12], "little")
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
big_encryption = cls.BigEncryption(data[13])
if big_encryption == cls.BigEncryption.BAD_CODE:
bad_code = data[14:30]
subgroups = decode_subgroups(data[30:])
else:
bad_code = b""
subgroups = decode_subgroups(data[14:])
return cls(
source_id,
source_address,
source_adv_sid,
broadcast_id,
pa_sync_state,
big_encryption,
bad_code,
subgroups,
)
def __bytes__(self) -> bytes:
return (
struct.pack(
"<BB6sB3sBB",
self.source_id,
self.source_address.address_type,
bytes(self.source_address),
self.source_adv_sid,
self.broadcast_id.to_bytes(3, "little"),
self.pa_sync_state,
self.big_encryption,
)
+ self.bad_code
+ encode_subgroups(self.subgroups)
)
# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
def __init__(self):
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
write=self.on_broadcast_audio_scan_control_point_write
),
)
self.broadcast_receive_state_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
b"12", # TEST
)
super().__init__([self.battery_level_characteristic])
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
) -> None:
pass
# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BroadcastAudioScanService
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [
gatt.DelegatedCharacteristicAdapter(
characteristic, decode=BroadcastReceiveState.from_bytes
)
for characteristic in characteristics
]
async def send_control_point_operation(
self, operation: ControlPointOperation
) -> None:
await self.broadcast_audio_scan_control_point.write_value(
bytes(operation), with_response=True
)
async def remote_scan_started(self) -> None:
await self.send_control_point_operation(RemoteScanStartedOperation())
async def remote_scan_stopped(self) -> None:
await self.send_control_point_operation(RemoteScanStoppedOperation())
async def add_source(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
AddSourceOperation(
advertiser_address,
advertising_sid,
broadcast_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def modify_source(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
ModifySourceOperation(
source_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def remove_source(self, source_id: int) -> None:
await self.send_control_point_operation(RemoveSourceOperation(source_id))

View File

@@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_member_rank: Optional[int] = None, set_member_rank: Optional[int] = None,
) -> None: ) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH: if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError( raise core.InvalidArgumentError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}' f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
) )
@@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
key = await connection.device.get_link_key(connection.peer_address) key = await connection.device.get_link_key(connection.peer_address)
if not key: if not key:
raise RuntimeError('LTK or LinkKey is not present') raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key) sirk_bytes = sef(key, self.set_identity_resolving_key)
@@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
'''Reads SIRK and decrypts if encrypted.''' '''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value() response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1: if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value') raise core.InvalidPacketError('Invalid SIRK value')
sirk_type = SirkType(response[0]) sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT: if sirk_type == SirkType.PLAINTEXT:
@@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
key = await device.get_link_key(connection.peer_address) key = await device.get_link_key(connection.peer_address)
if not key: if not key:
raise RuntimeError('LTK or LinkKey is not present') raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk = sef(key, response[1:]) sirk = sef(key, response[1:])

110
bumble/profiles/gap.py Normal file
View File

@@ -0,0 +1,110 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic Access Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Tuple, Union
from bumble.core import Appearance
from bumble.gatt import (
TemplateService,
Characteristic,
CharacteristicAdapter,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(TemplateService):
UUID = GATT_GENERIC_ACCESS_SERVICE
def __init__(
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
self.appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', appearance_int),
)
super().__init__(
[self.device_name_characteristic, self.appearance_characteristic]
)
# -----------------------------------------------------------------------------
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicAdapter]
appearance: Optional[DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC
):
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
else:
self.device_name = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_APPEARANCE_CHARACTERISTIC
):
self.appearance = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Appearance.from_int(
struct.unpack_from('<H', value, 0)[0],
),
)
else:
self.appearance = None

View File

@@ -19,6 +19,7 @@
from enum import IntEnum from enum import IntEnum
import struct import struct
from bumble import core
from ..gatt_client import ProfileServiceProxy from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error from ..att import ATT_Error
from ..gatt import ( from ..gatt import (
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
rr_intervals=None, rr_intervals=None,
): ):
if heart_rate < 0 or heart_rate > 0xFFFF: if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range') raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and ( if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF energy_expended < 0 or energy_expended > 0xFFFF
): ):
raise ValueError('energy_expended out of range') raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals: if rr_intervals:
for rr_interval in rr_intervals: for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range') raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected self.sensor_contact_detected = sensor_contact_detected

View File

@@ -17,33 +17,67 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
from typing import List import struct
from typing import List, Type
from typing_extensions import Self from typing_extensions import Self
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class Metadata: class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
again outside the lib.
'''
class Tag(utils.OpenIntEnum):
# fmt: off
PREFERRED_AUDIO_CONTEXTS = 0x01
STREAMING_AUDIO_CONTEXTS = 0x02
PROGRAM_INFO = 0x03
LANGUAGE = 0x04
CCID_LIST = 0x05
PARENTAL_RATING = 0x06
PROGRAM_INFO_URI = 0x07
AUDIO_ACTIVE_STATE = 0x08
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
ASSISTED_LISTENING_STREAM = 0x0A
BROADCAST_NAME = 0x0B
EXTENDED_METADATA = 0xFE
VENDOR_SPECIFIC = 0xFF
@dataclasses.dataclass @dataclasses.dataclass
class Entry: class Entry:
tag: int tag: Metadata.Tag
data: bytes data: bytes
entries: List[Entry] @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
def __bytes__(self) -> bytes:
return bytes([len(self.data) + 1, self.tag]) + self.data
entries: List[Entry] = dataclasses.field(default_factory=list)
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = [] entries = []
offset = 0 offset = 0
length = len(data) length = len(data)
while length >= 2: while offset < length:
entry_length = data[offset] entry_length = data[offset]
entry_tag = data[offset + 1] offset += 1
entry_data = data[offset + 2 : offset + 2 + entry_length - 1] entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
entries.append(cls.Entry(entry_tag, entry_data))
length -= entry_length
offset += entry_length offset += entry_length
return cls(entries) return cls(entries)
def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries])

448
bumble/profiles/mcp.py Normal file
View File

@@ -0,0 +1,448 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class PlayingOrder(utils.OpenIntEnum):
'''See Media Control Service 3.15. Playing Order.'''
SINGLE_ONCE = 0x01
SINGLE_REPEAT = 0x02
IN_ORDER_ONCE = 0x03
IN_ORDER_REPEAT = 0x04
OLDEST_ONCE = 0x05
OLDEST_REPEAT = 0x06
NEWEST_ONCE = 0x07
NEWEST_REPEAT = 0x08
SHUFFLE_ONCE = 0x09
SHUFFLE_REPEAT = 0x0A
class PlayingOrderSupported(enum.IntFlag):
'''See Media Control Service 3.16. Playing Orders Supported.'''
SINGLE_ONCE = 0x0001
SINGLE_REPEAT = 0x0002
IN_ORDER_ONCE = 0x0004
IN_ORDER_REPEAT = 0x0008
OLDEST_ONCE = 0x0010
OLDEST_REPEAT = 0x0020
NEWEST_ONCE = 0x0040
NEWEST_REPEAT = 0x0080
SHUFFLE_ONCE = 0x0100
SHUFFLE_REPEAT = 0x0200
class MediaState(utils.OpenIntEnum):
'''See Media Control Service 3.17. Media State.'''
INACTIVE = 0x00
PLAYING = 0x01
PAUSED = 0x02
SEEKING = 0x03
class MediaControlPointOpcode(utils.OpenIntEnum):
'''See Media Control Service 3.18. Media Control Point.'''
PLAY = 0x01
PAUSE = 0x02
FAST_REWIND = 0x03
FAST_FORWARD = 0x04
STOP = 0x05
MOVE_RELATIVE = 0x10
PREVIOUS_SEGMENT = 0x20
NEXT_SEGMENT = 0x21
FIRST_SEGMENT = 0x22
LAST_SEGMENT = 0x23
GOTO_SEGMENT = 0x24
PREVIOUS_TRACK = 0x30
NEXT_TRACK = 0x31
FIRST_TRACK = 0x32
LAST_TRACK = 0x33
GOTO_TRACK = 0x34
PREVIOUS_GROUP = 0x40
NEXT_GROUP = 0x41
FIRST_GROUP = 0x42
LAST_GROUP = 0x43
GOTO_GROUP = 0x44
class MediaControlPointResultCode(enum.IntFlag):
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
SUCCESS = 0x01
OPCODE_NOT_SUPPORTED = 0x02
MEDIA_PLAYER_INACTIVE = 0x03
COMMAND_CANNOT_BE_COMPLETED = 0x04
class MediaControlPointOpcodeSupported(enum.IntFlag):
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
PLAY = 0x00000001
PAUSE = 0x00000002
FAST_REWIND = 0x00000004
FAST_FORWARD = 0x00000008
STOP = 0x00000010
MOVE_RELATIVE = 0x00000020
PREVIOUS_SEGMENT = 0x00000040
NEXT_SEGMENT = 0x00000080
FIRST_SEGMENT = 0x00000100
LAST_SEGMENT = 0x00000200
GOTO_SEGMENT = 0x00000400
PREVIOUS_TRACK = 0x00000800
NEXT_TRACK = 0x00001000
FIRST_TRACK = 0x00002000
LAST_TRACK = 0x00004000
GOTO_TRACK = 0x00008000
PREVIOUS_GROUP = 0x00010000
NEXT_GROUP = 0x00020000
FIRST_GROUP = 0x00040000
LAST_GROUP = 0x00080000
GOTO_GROUP = 0x00100000
class SearchControlPointItemType(utils.OpenIntEnum):
'''See Media Control Service 3.20. Search Control Point.'''
TRACK_NAME = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
GROUP_NAME = 0x04
EARLIEST_YEAR = 0x05
LATEST_YEAR = 0x06
GENRE = 0x07
ONLY_TRACKS = 0x08
ONLY_GROUPS = 0x09
class ObjectType(utils.OpenIntEnum):
'''See Media Control Service 4.4.1. Object Type field.'''
TASK = 0
GROUP = 1
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectId(int):
'''See Media Control Service 4.4.2. Object ID field.'''
@classmethod
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(int.from_bytes(data, byteorder='little', signed=False))
def __bytes__(self) -> bytes:
return self.to_bytes(6, 'little')
@dataclasses.dataclass
class GroupObjectType:
'''See Media Control Service 4.4. Group Object Type.'''
object_type: ObjectType
object_id: ObjectId
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(
object_type=ObjectType(data[0]),
object_id=ObjectId.create_from_bytes(data[1:]),
)
def __bytes__(self) -> bytes:
return bytes([self.object_type]) + bytes(self.object_id)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class MediaControlService(gatt.TemplateService):
'''Media Control Service server implementation, only for testing currently.'''
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player',
)
self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_title_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_duration_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_position_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_state_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_control_point_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.content_control_id_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
super().__init__(
[
self.media_player_name_characteristic,
self.track_changed_characteristic,
self.track_title_characteristic,
self.track_duration_characteristic,
self.track_position_characteristic,
self.media_state_characteristic,
self.media_control_point_characteristic,
self.media_control_point_opcodes_supported_characteristic,
self.content_control_id_characteristic,
]
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(
connection,
self.media_control_point_characteristic,
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
)
class GenericMediaControlService(MediaControlService):
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class MediaControlServiceProxy(
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
):
SERVICE_CLASS = MediaControlService
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
}
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
track_changed: Optional[gatt_client.CharacteristicProxy] = None
track_title: Optional[gatt_client.CharacteristicProxy] = None
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
media_state: Optional[gatt_client.CharacteristicProxy] = None
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
None
)
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
utils.CompositeEventEmitter.__init__(self)
self.service_proxy = service_proxy
self.lock = asyncio.Lock()
self.media_control_point_notifications = asyncio.Queue()
for field, uuid in self._CHARACTERISTICS.items():
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, field, characteristics[0])
async def subscribe_characteristics(self) -> None:
if self.media_control_point:
await self.media_control_point.subscribe(self._on_media_control_point)
if self.media_state:
await self.media_state.subscribe(self._on_media_state)
if self.track_changed:
await self.track_changed.subscribe(self._on_track_changed)
if self.track_title:
await self.track_title.subscribe(self._on_track_title)
if self.track_duration:
await self.track_duration.subscribe(self._on_track_duration)
if self.track_position:
await self.track_position.subscribe(self._on_track_position)
async def write_control_point(
self, opcode: MediaControlPointOpcode
) -> MediaControlPointResultCode:
'''Writes a Media Control Point Opcode to peer and waits for the notification.
The write operation will be executed when there isn't other pending commands.
Args:
opcode: opcode defined in `MediaControlPointOpcode`.
Returns:
Response code provided in `MediaControlPointResultCode`
Raises:
InvalidOperationError: Server does not have Media Control Point Characteristic.
InvalidStateError: Server replies a notification with mismatched opcode.
'''
if not self.media_control_point:
raise core.InvalidOperationError("Peer does not have media control point")
async with self.lock:
await self.media_control_point.write_value(
bytes([opcode]),
with_response=False,
)
(
response_opcode,
response_code,
) = await self.media_control_point_notifications.get()
if response_opcode != opcode:
raise core.InvalidStateError(
f"Expected {opcode} notification, but get {response_opcode}"
)
return MediaControlPointResultCode(response_code)
def _on_media_control_point(self, data: bytes) -> None:
self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None:
del data
self.emit('track_changed')
def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
SERVICE_CLASS = GenericMediaControlService

210
bumble/profiles/pacs.py Normal file
View File

@@ -0,0 +1,210 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Published Audio Capabilities Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import Optional, Sequence, Union
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
from bumble.profiles import le_audio
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]

89
bumble/profiles/tmap.py Normal file
View File

@@ -0,0 +1,89 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Telephony and Media Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import logging
import struct
from bumble.gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Role(enum.IntFlag):
CALL_GATEWAY = 1 << 0
CALL_TERMINAL = 1 << 1
UNICAST_MEDIA_SENDER = 1 << 2
UNICAST_MEDIA_RECEIVER = 1 << 3
BROADCAST_MEDIA_SENDER = 1 << 4
BROADCAST_MEDIA_RECEIVER = 1 << 5
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioService(TemplateService):
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
def __init__(self, role: Role):
self.role_characteristic = Characteristic(
GATT_TMAP_ROLE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', int(role)),
)
super().__init__([self.role_characteristic])
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = TelephonyAndMediaAudioService
role: DelegatedCharacteristicAdapter
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0],
),
)

View File

@@ -36,7 +36,9 @@ from .core import (
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
InvalidArgumentError,
InvalidStateError, InvalidStateError,
InvalidPacketError,
ProtocolError, ProtocolError,
) )
@@ -335,7 +337,7 @@ class RFCOMM_Frame:
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information) frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs: if frame.fcs != fcs:
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}') logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch') raise InvalidPacketError('fcs mismatch')
return frame return frame
@@ -713,7 +715,7 @@ class DLC(EventEmitter):
# Automatically convert strings to bytes using UTF-8 # Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8') data = data.encode('utf-8')
else: else:
raise ValueError('write only accept bytes or strings') raise InvalidArgumentError('write only accept bytes or strings')
self.tx_buffer += data self.tx_buffer += data
self.drained.clear() self.drained.clear()

View File

@@ -23,7 +23,7 @@ from typing_extensions import Self
from . import core, l2cap from . import core, l2cap
from .colors import color from .colors import color
from .core import InvalidStateError from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
from .hci import HCI_Object, name_or_number, key_with_value from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -189,7 +189,9 @@ class DataElement:
self.bytes = None self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None: if value_size is None:
raise ValueError('integer types must have a value size specified') raise InvalidArgumentError(
'integer types must have a value size specified'
)
@staticmethod @staticmethod
def nil() -> DataElement: def nil() -> DataElement:
@@ -265,7 +267,7 @@ class DataElement:
if len(data) == 8: if len(data) == 8:
return struct.unpack('>Q', data)[0] return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}') raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def signed_integer_from_bytes(data): def signed_integer_from_bytes(data):
@@ -281,7 +283,7 @@ class DataElement:
if len(data) == 8: if len(data) == 8:
return struct.unpack('>q', data)[0] return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}') raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def list_from_bytes(data): def list_from_bytes(data):
@@ -354,7 +356,7 @@ class DataElement:
data = b'' data = b''
elif self.type == DataElement.UNSIGNED_INTEGER: elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0: if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative') raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1: if self.value_size == 1:
data = struct.pack('B', self.value) data = struct.pack('B', self.value)
@@ -365,7 +367,7 @@ class DataElement:
elif self.value_size == 8: elif self.value_size == 8:
data = struct.pack('>Q', self.value) data = struct.pack('>Q', self.value)
else: else:
raise ValueError('invalid value_size') raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER: elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1: if self.value_size == 1:
data = struct.pack('b', self.value) data = struct.pack('b', self.value)
@@ -376,7 +378,7 @@ class DataElement:
elif self.value_size == 8: elif self.value_size == 8:
data = struct.pack('>q', self.value) data = struct.pack('>q', self.value)
else: else:
raise ValueError('invalid value_size') raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID: elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value))) data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL: elif self.type == DataElement.URL:
@@ -392,7 +394,7 @@ class DataElement:
size_bytes = b'' size_bytes = b''
if self.type == DataElement.NIL: if self.type == DataElement.NIL:
if size != 0: if size != 0:
raise ValueError('NIL must be empty') raise InvalidArgumentError('NIL must be empty')
size_index = 0 size_index = 0
elif self.type in ( elif self.type in (
DataElement.UNSIGNED_INTEGER, DataElement.UNSIGNED_INTEGER,
@@ -410,7 +412,7 @@ class DataElement:
elif size == 16: elif size == 16:
size_index = 4 size_index = 4
else: else:
raise ValueError('invalid data size') raise InvalidArgumentError('invalid data size')
elif self.type in ( elif self.type in (
DataElement.TEXT_STRING, DataElement.TEXT_STRING,
DataElement.SEQUENCE, DataElement.SEQUENCE,
@@ -427,10 +429,10 @@ class DataElement:
size_index = 7 size_index = 7
size_bytes = struct.pack('>I', size) size_bytes = struct.pack('>I', size)
else: else:
raise ValueError('invalid data size') raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN: elif self.type == DataElement.BOOLEAN:
if size != 1: if size != 1:
raise ValueError('boolean must be 1 byte') raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0 size_index = 0
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data

View File

@@ -55,6 +55,7 @@ from .core import (
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
AdvertisingData, AdvertisingData,
InvalidArgumentError,
ProtocolError, ProtocolError,
name_or_number, name_or_number,
) )
@@ -766,8 +767,11 @@ class Session:
self.oob_data_flag = 0 if pairing_config.oob is None else 1 self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses # Set up addresses
self_address = connection.self_address self_address = connection.self_resolvable_address or connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address peer_address = connection.peer_resolvable_address or connection.peer_address
logger.debug(
f"pairing with self_address={self_address}, peer_address={peer_address}"
)
if self.is_initiator: if self.is_initiator:
self.ia = bytes(self_address) self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0 self.iat = 1 if self_address.is_random else 0
@@ -784,7 +788,7 @@ class Session:
self.peer_oob_data = pairing_config.oob.peer_data self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc: if pairing_config.sc:
if pairing_config.oob.our_context is None: if pairing_config.oob.our_context is None:
raise ValueError( raise InvalidArgumentError(
"oob pairing config requires a context when sc is True" "oob pairing config requires a context when sc is True"
) )
self.r = pairing_config.oob.our_context.r self.r = pairing_config.oob.our_context.r
@@ -793,7 +797,7 @@ class Session:
self.tk = pairing_config.oob.legacy_context.tk self.tk = pairing_config.oob.legacy_context.tk
else: else:
if pairing_config.oob.legacy_context is None: if pairing_config.oob.legacy_context is None:
raise ValueError( raise InvalidArgumentError(
"oob pairing config requires a legacy context when sc is False" "oob pairing config requires a legacy context when sc is False"
) )
self.r = bytes(16) self.r = bytes(16)
@@ -1074,11 +1078,19 @@ class Session:
) )
def send_identity_address_command(self) -> None: def send_identity_address_command(self) -> None:
identity_address = { if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
None: self.connection.self_address, identity_address = self.manager.device.public_address
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address, elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address, identity_address = self.manager.device.static_address
}[self.pairing_config.identity_address_type] else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
self.send_command( self.send_command(
SMP_Identity_Address_Information_Command( SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type, addr_type=identity_address.address_type,

View File

@@ -23,6 +23,7 @@ import datetime
from typing import BinaryIO, Generator from typing import BinaryIO, Generator
import os import os
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
""" """
if ':' not in spec: if ':' not in spec:
raise ValueError('snooper type prefix missing') raise core.InvalidArgumentError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1) snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop': if snooper_type == 'btsnoop':
if ':' not in snooper_args: if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing') raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1) io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file': if io_type == 'file':
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1 _SNOOPER_INSTANCE_COUNT -= 1
return return
raise ValueError(f'I/O type {io_type} not supported') raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found') raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')

View File

@@ -20,7 +20,7 @@ import logging
import os import os
from typing import Optional from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
from ..snoop import create_snooper from ..snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -180,7 +180,13 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme') if scheme == 'unix':
from .unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
raise TransportSpecError('unknown transport scheme')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -20,7 +20,13 @@ import grpc.aio
from typing import Optional, Union from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport from .common import (
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
)
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
elif ':' in param: elif ':' in param:
server_host, server_port = param.split(':') server_host, server_port = param.split(':')
else: else:
raise ValueError('invalid parameter') raise TransportSpecError('invalid parameter')
# Connect to the gRPC server # Connect to the gRPC server
server_address = f'{server_host}:{server_port}' server_address = f'{server_host}:{server_port}'
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
service = VhciForwardingServiceStub(channel) service = VhciForwardingServiceStub(channel)
hci_device = HciDevice(service.attachVhci()) hci_device = HciDevice(service.attachVhci())
else: else:
raise ValueError('invalid mode') raise TransportSpecError('invalid mode')
# Create the transport object # Create the transport object
class EmulatorTransport(PumpedTransport): class EmulatorTransport(PumpedTransport):

View File

@@ -31,6 +31,8 @@ from .common import (
PumpedPacketSource, PumpedPacketSource,
PumpedPacketSink, PumpedPacketSink,
Transport, Transport,
TransportSpecError,
TransportInitError,
) )
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str] server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport: ) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise TransportSpecError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:
server_host = 'localhost' server_host = 'localhost'
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
instance_number = 0 if options is None else int(options.get('instance', '0')) instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number) server_port = find_grpc_port(instance_number)
if not server_port: if not server_port:
raise RuntimeError('gRPC server port not found') raise TransportInitError('gRPC server port not found')
# Connect to the gRPC server # Connect to the gRPC server
server_address = f'{server_host}:{server_port}' server_address = f'{server_host}:{server_port}'
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
if response_type == 'error': if response_type == 'error':
logger.warning(f'received error: {response.error}') logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error) raise TransportInitError(response.error)
if response_type == 'hci_packet': if response_type == 'hci_packet':
return ( return (
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
+ response.hci_packet.packet + response.hci_packet.packet
) )
raise ValueError('unsupported response type') raise TransportSpecError('unsupported response type')
async def write(self, packet): async def write(self, packet):
await self.hci_device.write( await self.hci_device.write(
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
options: Dict[str, str] = {} options: Dict[str, str] = {}
for param in params[params_offset:]: for param in params[params_offset:]:
if '=' not in param: if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>') raise TransportSpecError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=') option_name, option_value = param.split('=')
options[option_name] = option_value options[option_name] = option_value
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
) )
if mode == 'controller': if mode == 'controller':
if host is None: if host is None:
raise ValueError('<host>:<port> missing') raise TransportSpecError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options) return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option') raise TransportSpecError('invalid mode option')

View File

@@ -23,6 +23,7 @@ import logging
import io import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import core
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble.snoop import Snooper from bumble.snoop import Snooper
@@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Errors # Errors
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TransportLostError(Exception): class TransportLostError(core.BaseBumbleError, RuntimeError):
""" """The Transport has been lost/disconnected."""
The Transport has been lost/disconnected.
"""
class TransportInitError(core.BaseBumbleError, RuntimeError):
"""Error raised when the transport cannot be initialized."""
class TransportSpecError(core.BaseBumbleError, ValueError):
"""Error raised when the transport spec is invalid."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -132,7 +139,9 @@ class PacketParser:
packet_type packet_type
) or self.extended_packet_info.get(packet_type) ) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}') raise core.InvalidPacketError(
f'invalid packet type {packet_type}'
)
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
@@ -178,19 +187,19 @@ class PacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
header = self.source.read(header_size) header = self.source.read(header_size)
if len(header) != header_size: if len(header) != header_size:
raise ValueError('packet too short') raise core.InvalidPacketError('packet too short')
# Read the body # Read the body
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0] body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
body = self.source.read(body_length) body = self.source.read(body_length)
if len(body) != body_length: if len(body) != body_length:
raise ValueError('packet too short') raise core.InvalidPacketError('packet too short')
return packet_type + header + body return packet_type + header + body
@@ -211,7 +220,7 @@ class AsyncPacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -239,26 +248,28 @@ class AsyncPipeSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ParserSource: class BaseSource:
""" """
Base class designed to be subclassed by transport-specific source classes Base class designed to be subclassed by transport-specific source classes
""" """
terminated: asyncio.Future[None] terminated: asyncio.Future[None]
parser: PacketParser sink: Optional[TransportSink]
def __init__(self) -> None: def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
self.sink = None
def set_packet_sink(self, sink: TransportSink) -> None: def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink) self.sink = sink
def on_transport_lost(self) -> None: def on_transport_lost(self) -> None:
self.terminated.set_result(None) if not self.terminated.done():
if self.parser.sink: self.terminated.set_result(None)
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost() if self.sink:
if hasattr(self.sink, 'on_transport_lost'):
self.sink.on_transport_lost()
async def wait_for_termination(self) -> None: async def wait_for_termination(self) -> None:
""" """
@@ -271,6 +282,23 @@ class ParserSource:
pass pass
# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
"""
Base class for sources that use an HCI parser.
"""
parser: PacketParser
def __init__(self) -> None:
super().__init__()
self.parser = PacketParser()
def set_packet_sink(self, sink: TransportSink) -> None:
super().set_packet_sink(sink)
self.parser.set_packet_sink(sink)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
@@ -420,7 +448,7 @@ class SnoopingTransport(Transport):
return SnoopingTransport( return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
) )
raise RuntimeError('unexpected code path') # Satisfy the type checker raise core.UnreachableError() # Satisfy the type checker
class Source: class Source:
sink: TransportSink sink: TransportSink

View File

@@ -29,7 +29,7 @@ from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource from .common import Transport, ParserSource, TransportInitError
from .. import hci from .. import hci
from ..colors import color from ..colors import color
@@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
device = None device = None
if device is None: if device is None:
raise ValueError('device not found') raise TransportInitError('device not found')
logger.debug(f'USB Device: {device}') logger.debug(f'USB Device: {device}')
# Power Cycle the device # Power Cycle the device

56
bumble/transport/unix.py Normal file
View File

@@ -0,0 +1,56 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_unix_client_transport(spec: str) -> Transport:
'''Open a UNIX socket client transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
class UnixPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.on_transport_lost()
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
(
unix_transport,
packet_source,
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink)

View File

@@ -15,19 +15,18 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import threading import threading
import collections
import ctypes import ctypes
import platform import platform
import usb1 import usb1
from bumble.transport.common import Transport, ParserSource from bumble.transport.common import Transport, BaseSource, TransportInitError
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -115,13 +114,17 @@ async def open_usb_transport(spec: str) -> Transport:
self.device = device self.device = device
self.acl_out = acl_out self.acl_out = acl_out
self.acl_out_transfer = device.getTransfer() self.acl_out_transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent self.acl_out_transfer_ready = asyncio.Semaphore(1)
self.packets: asyncio.Queue[bytes] = (
asyncio.Queue()
) # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue_task = None
self.cancel_done = self.loop.create_future() self.cancel_done = self.loop.create_future()
self.closed = False self.closed = False
def start(self): def start(self):
pass self.queue_task = asyncio.create_task(self.process_queue())
def on_packet(self, packet): def on_packet(self, packet):
# Ignore packets if we're closed # Ignore packets if we're closed
@@ -133,62 +136,64 @@ async def open_usb_transport(spec: str) -> Transport:
return return
# Queue the packet # Queue the packet
self.packets.append(packet) self.packets.put_nowait(packet)
if len(self.packets) == 1:
# The queue was previously empty, re-prime the pump
self.process_queue()
def transfer_callback(self, transfer): def transfer_callback(self, transfer):
self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
status = transfer.getStatus() status = transfer.getStatus()
# pylint: disable=no-member # pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.on_packet_sent)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else: return
if status != usb1.TRANSFER_COMPLETED:
logger.warning( logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red') color(f'!!! OUT transfer not completed: status={status}', 'red')
) )
def on_packet_sent(self): async def process_queue(self):
if self.packets: while True:
self.packets.popleft() # Wait for a packet to transfer.
self.process_queue() packet = await self.packets.get()
def process_queue(self): # Wait until we can start a transfer.
if len(self.packets) == 0: await self.acl_out_transfer_ready.acquire()
return # Nothing to do
packet = self.packets[0] # Transfer the packet.
packet_type = packet[0] packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk( self.acl_out_transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback self.acl_out, packet[1:], callback=self.transfer_callback
) )
self.acl_out_transfer.submit() self.acl_out_transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl( self.acl_out_transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0, 0,
0, 0,
0, 0,
packet[1:], packet[1:],
callback=self.transfer_callback, callback=self.transfer_callback,
) )
self.acl_out_transfer.submit() self.acl_out_transfer.submit()
else: else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red')) logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
def close(self): def close(self):
self.closed = True self.closed = True
if self.queue_task:
self.queue_task.cancel()
async def terminate(self): async def terminate(self):
if not self.closed: if not self.closed:
self.close() self.close()
# Empty the packet queue so that we don't send any more data # Empty the packet queue so that we don't send any more data
self.packets.clear() while not self.packets.empty():
self.packets.get_nowait()
# If we have a transfer in flight, cancel it # If we have a transfer in flight, cancel it
if self.acl_out_transfer.isSubmitted(): if self.acl_out_transfer.isSubmitted():
@@ -203,7 +208,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError: except usb1.USBError:
logger.debug('OUT transfer likely already completed') logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, BaseSource):
def __init__(self, device, metadata, acl_in, events_in): def __init__(self, device, metadata, acl_in, events_in):
super().__init__() super().__init__()
self.device = device self.device = device
@@ -280,7 +285,13 @@ async def open_usb_transport(spec: str) -> Transport:
packet = await self.queue.get() packet = await self.queue.get()
except asyncio.CancelledError: except asyncio.CancelledError:
return return
self.parser.feed_data(packet) if self.sink:
try:
self.sink.on_packet(packet)
except Exception as error:
logger.exception(
color(f'!!! Exception in sink.on_packet: {error}', 'red')
)
def close(self): def close(self):
self.closed = True self.closed = True
@@ -442,7 +453,7 @@ async def open_usb_transport(spec: str) -> Transport:
if found is None: if found is None:
context.close() context.close()
raise ValueError('device not found') raise TransportInitError('device not found')
logger.debug(f'USB Device: {found}') logger.debug(f'USB Device: {found}')
@@ -507,7 +518,7 @@ async def open_usb_transport(spec: str) -> Transport:
endpoints = find_endpoints(found) endpoints = find_endpoints(found)
if endpoints is None: if endpoints is None:
raise ValueError('no compatible interface found for device') raise TransportInitError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints (configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug( logger.debug(
f'selected endpoints: configuration={configuration}, ' f'selected endpoints: configuration={configuration}, '

BIN
docs/images/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@@ -0,0 +1,7 @@
{
"name": "Bumble",
"address": "F0:F1:F2:F3:F4:F5",
"keystore": "JsonKeyStore",
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
"le_privacy_enabled": true
}

View File

@@ -3,5 +3,6 @@
"keystore": "JsonKeyStore", "keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA", "address": "F0:F1:F2:F3:F4:FA",
"class_of_device": 2376708, "class_of_device": 2376708,
"cis_enabled": true,
"advertising_interval": 100 "advertising_interval": 100
} }

83
examples/mcp_server.html Normal file
View File

@@ -0,0 +1,83 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble LEA Media Control Client</span>
</div>
</nav>
<br>
<div class="container">
<label class="form-label">Server Port</label>
<div class="input-group mb-3">
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
</div>
<button class="btn btn-primary" onclick="send_opcode(0x01)">Play</button>
<button class="btn btn-primary" onclick="send_opcode(0x02)">Pause</button>
<button class="btn btn-primary" onclick="send_opcode(0x03)">Fast Rewind</button>
<button class="btn btn-primary" onclick="send_opcode(0x04)">Fast Forward</button>
<button class="btn btn-primary" onclick="send_opcode(0x05)">Stop</button>
</br></br>
<button class="btn btn-primary" onclick="send_opcode(0x30)">Previous Track</button>
<button class="btn btn-primary" onclick="send_opcode(0x31)">Next Track</button>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let portInput = document.getElementById("port")
let log = document.getElementById("log")
let socket
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
log.textContent += 'OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = (event) => {
log.textContent += `<-- ${event.data}\n`
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
let jsonMessage = JSON.stringify(message)
log.textContent += `--> ${jsonMessage}\n`
socket.send(jsonMessage)
} else {
log.textContent += 'NOT CONNECTED\n'
}
}
function send_opcode(opcode) {
send({ 'opcode': opcode })
}
</script>
</div>
</body>
</html>

View File

@@ -21,7 +21,7 @@ import os
import logging import logging
import json import json
import websockets import websockets
from bumble.colors import color import struct
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -30,9 +30,7 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE, BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID, BT_HIDP_PROTOCOL_ID,
UUID,
) )
from bumble.hci import Address
from bumble.hid import ( from bumble.hid import (
Device as HID_Device, Device as HID_Device,
HID_CONTROL_PSM, HID_CONTROL_PSM,
@@ -40,20 +38,17 @@ from bumble.hid import (
Message, Message,
) )
from bumble.sdp import ( from bumble.sdp import (
Client as SDP_Client,
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT, SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
) )
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices # SDP attributes for Bluetooth HID devices
@@ -430,7 +425,7 @@ deviceData = DeviceData()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def keyboard_device(hid_device): async def keyboard_device(hid_device: HID_Device):
# Start a Websocket server to receive events from a web page # Start a Websocket server to receive events from a web page
async def serve(websocket, _path): async def serve(websocket, _path):
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
# limiting x and y values within logical max and min range # limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x)) x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y)) y = max(log_min, min(log_max, y))
x_cord = x.to_bytes(signed=True) deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
y_cord = y.to_bytes(signed=True) ">bb", x, y
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord )
hid_device.send_data(deviceData.mouseData) hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK: except websockets.exceptions.ConnectionClosedOK:
pass pass
@@ -515,7 +510,9 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes): def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}') print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int): def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus() retValue = hid_device.GetSetStatus()
print( print(
"GET_REPORT report_id: " "GET_REPORT report_id: "
@@ -555,8 +552,7 @@ async def main() -> None:
def on_set_report_cb( def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes report_id: int, report_type: int, report_size: int, data: bytes
): ) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
print( print(
"SET_REPORT report_id: " "SET_REPORT report_id: "
+ str(report_id) + str(report_id)
@@ -568,33 +564,33 @@ async def main() -> None:
+ str(data) + str(data)
) )
if report_type == Message.ReportType.FEATURE_REPORT: if report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT: elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData): if report_id == 1 and report_size != len(deviceData.keyboardData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData): elif report_id == 2 and report_size != len(deviceData.mouseData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3: elif report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
else: else:
retValue.status = hid_device.GetSetReturn.SUCCESS status = HID_Device.GetSetReturn.SUCCESS
else: else:
retValue.status = hid_device.GetSetReturn.SUCCESS status = HID_Device.GetSetReturn.SUCCESS
return retValue return HID_Device.GetSetStatus(status=status)
def on_get_protocol_cb(): def on_get_protocol_cb() -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus() return HID_Device.GetSetStatus(
retValue.data = protocol_mode.to_bytes() data=bytes([protocol_mode]),
retValue.status = hid_device.GetSetReturn.SUCCESS status=hid_device.GetSetReturn.SUCCESS,
return retValue )
def on_set_protocol_cb(protocol: int): def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL. # We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}") print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST return HID_Device.GetSetStatus(
return retValue status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
def on_virtual_cable_unplug_cb(): def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug') print('Received Virtual Cable Unplug')

194
examples/run_mcp_client.py Normal file
View File

@@ -0,0 +1,194 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
import websockets
import json
from bumble.core import AdvertisingData
from bumble.device import (
Device,
AdvertisingParameters,
AdvertisingEventProperties,
Connection,
Peer,
)
from bumble.hci import (
CodecID,
CodingFormat,
OwnAddressType,
)
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import (
CodecSpecificCapabilities,
ContextType,
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
UnicastServerAdvertisingData,
)
from bumble.profiles.mcp import (
MediaControlServiceProxy,
GenericMediaControlServiceProxy,
MediaState,
MediaControlPointOpcode,
)
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link
from typing import Optional
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_mcp_client.py <config-file>' '<transport-spec-for-device>')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
# Add "placeholder" services to enable Android LEA features.
device.add_service(
PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED,
available_source_context=ContextType.PROHIBITED,
supported_sink_context=ContextType.MEDIA,
available_sink_context=ContextType.MEDIA,
sink_audio_locations=(
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
),
sink_pac=[
PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_16000
| SupportedSamplingFrequency.FREQ_32000
| SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=0,
max_octets_per_codec_frame=320,
supported_max_codec_frames_per_sdu=2,
),
),
],
)
)
device.add_service(AudioStreamControlService(device, sink_ase_id=[1]))
ws: Optional[websockets.WebSocketServerProtocol] = None
mcp: Optional[MediaControlServiceProxy] = None
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(UnicastServerAdvertisingData())
await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(
advertising_event_properties=AdvertisingEventProperties(),
own_address_type=OwnAddressType.RANDOM,
primary_advertising_interval_max=100,
primary_advertising_interval_min=100,
),
advertising_data=advertising_data,
auto_restart=True,
)
def on_media_state(media_state: MediaState) -> None:
if ws:
asyncio.create_task(
ws.send(json.dumps({'media_state': media_state.name}))
)
def on_track_title(title: str) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'title': title})))
def on_track_duration(duration: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'duration': duration})))
def on_track_position(position: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'position': position})))
def on_connection(connection: Connection) -> None:
async def on_connection_async():
async with Peer(connection) as peer:
nonlocal mcp
mcp = peer.create_service_proxy(MediaControlServiceProxy)
if not mcp:
mcp = peer.create_service_proxy(GenericMediaControlServiceProxy)
mcp.on('media_state', on_media_state)
mcp.on('track_title', on_track_title)
mcp.on('track_duration', on_track_duration)
mcp.on('track_position', on_track_position)
await mcp.subscribe_characteristics()
connection.abort_on('disconnection', on_connection_async())
device.on('connection', on_connection)
async def serve(websocket: websockets.WebSocketServerProtocol, _path):
nonlocal ws
ws = websocket
async for message in websocket:
request = json.loads(message)
if mcp:
await mcp.write_control_point(
MediaControlPointOpcode(request['opcode'])
)
ws = None
await websockets.serve(serve, 'localhost', 8989)
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -34,8 +34,8 @@ from bumble.hci import (
CodingFormat, CodingFormat,
HCI_IsoDataPacket, HCI_IsoDataPacket,
) )
from bumble.profiles.ascs import AseStateMachine, AudioStreamControlService
from bumble.profiles.bap import ( from bumble.profiles.bap import (
AseStateMachine,
UnicastServerAdvertisingData, UnicastServerAdvertisingData,
CodecSpecificConfiguration, CodecSpecificConfiguration,
CodecSpecificCapabilities, CodecSpecificCapabilities,
@@ -43,13 +43,10 @@ from bumble.profiles.bap import (
AudioLocation, AudioLocation,
SupportedSamplingFrequency, SupportedSamplingFrequency,
SupportedFrameDuration, SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
) )
from bumble.profiles.cap import CommonAudioServiceService from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link

View File

@@ -30,6 +30,7 @@ from bumble.hci import (
CodingFormat, CodingFormat,
OwnAddressType, OwnAddressType,
) )
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import ( from bumble.profiles.bap import (
UnicastServerAdvertisingData, UnicastServerAdvertisingData,
CodecSpecificCapabilities, CodecSpecificCapabilities,
@@ -37,10 +38,8 @@ from bumble.profiles.bap import (
AudioLocation, AudioLocation,
SupportedSamplingFrequency, SupportedSamplingFrequency,
SupportedFrameDuration, SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
) )
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.profiles.cap import CommonAudioServiceService from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.vcp import VolumeControlService from bumble.profiles.vcp import VolumeControlService

View File

@@ -142,7 +142,7 @@ class MainActivity : ComponentActivity() {
::runRfcommClient, ::runRfcommClient,
::runRfcommServer, ::runRfcommServer,
::runL2capClient, ::runL2capClient,
::runL2capServer ::runL2capServer,
) )
} }
@@ -166,6 +166,8 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> runRfcommServer() "rfcomm-server" -> runRfcommServer()
"l2cap-client" -> runL2capClient() "l2cap-client" -> runL2capClient()
"l2cap-server" -> runL2capServer() "l2cap-server" -> runL2capServer()
"scan-start" -> runScan(true)
"stop-start" -> runScan(false)
} }
} }
} }
@@ -190,6 +192,11 @@ class MainActivity : ComponentActivity() {
l2capServer?.run() l2capServer?.run()
} }
private fun runScan(startScan: Boolean) {
val scan = bluetoothAdapter?.let { Scan(it) }
scan?.run(startScan)
}
@SuppressLint("MissingPermission") @SuppressLint("MissingPermission")
fun becomeDiscoverable() { fun becomeDiscoverable() {
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE) val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
@@ -206,7 +213,7 @@ fun MainView(
runRfcommClient: () -> Unit, runRfcommClient: () -> Unit,
runRfcommServer: () -> Unit, runRfcommServer: () -> Unit,
runL2capClient: () -> Unit, runL2capClient: () -> Unit,
runL2capServer: () -> Unit runL2capServer: () -> Unit,
) { ) {
BTBenchTheme { BTBenchTheme {
val scrollState = rememberScrollState() val scrollState = rememberScrollState()

View File

@@ -0,0 +1,38 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.le.ScanCallback
import android.bluetooth.le.ScanResult
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.scan")
class Scan(val bluetoothAdapter: BluetoothAdapter) {
@SuppressLint("MissingPermission")
fun run(startScan: Boolean) {
var bluetoothLeScanner = bluetoothAdapter.bluetoothLeScanner
val scanCallback = object : ScanCallback() {
override fun onScanResult(callbackType: Int, result: ScanResult?) {
super.onScanResult(callbackType, result)
val device: BluetoothDevice? = result?.device
val deviceName = device?.name ?: "Unknown"
val deviceAddress = device?.address ?: "Unknown"
Log.info("Device found: $deviceName ($deviceAddress)")
}
override fun onScanFailed(errorCode: Int) {
// Handle scan failure
Log.warning("Scan failed with error code: $errorCode")
}
}
if (startScan) {
bluetoothLeScanner?.startScan(scanCallback)
} else {
bluetoothLeScanner?.stopScan(scanCallback)
}
}
}

View File

@@ -23,8 +23,9 @@ import logging
from bumble import device from bumble import device
from bumble.hci import CodecID, CodingFormat from bumble.hci import CodecID, CodingFormat
from bumble.profiles.bap import ( from bumble.profiles.ascs import (
AudioLocation, AudioStreamControlService,
AudioStreamControlServiceProxy,
AseStateMachine, AseStateMachine,
ASE_Operation, ASE_Operation,
ASE_Config_Codec, ASE_Config_Codec,
@@ -35,6 +36,9 @@ from bumble.profiles.bap import (
ASE_Receiver_Stop_Ready, ASE_Receiver_Stop_Ready,
ASE_Release, ASE_Release,
ASE_Update_Metadata, ASE_Update_Metadata,
)
from bumble.profiles.bap import (
AudioLocation,
SupportedFrameDuration, SupportedFrameDuration,
SupportedSamplingFrequency, SupportedSamplingFrequency,
SamplingFrequency, SamplingFrequency,
@@ -42,12 +46,13 @@ from bumble.profiles.bap import (
CodecSpecificCapabilities, CodecSpecificCapabilities,
CodecSpecificConfiguration, CodecSpecificConfiguration,
ContextType, ContextType,
)
from bumble.profiles.pacs import (
PacRecord, PacRecord,
AudioStreamControlService,
AudioStreamControlServiceProxy,
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy, PublishedAudioCapabilitiesServiceProxy,
) )
from bumble.profiles.le_audio import Metadata
from tests.test_utils import TwoDevices from tests.test_utils import TwoDevices
@@ -97,7 +102,7 @@ def test_pac_record() -> None:
pac_record = PacRecord( pac_record = PacRecord(
coding_format=CodingFormat(CodecID.LC3), coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=cap, codec_specific_capabilities=cap,
metadata=b'', metadata=Metadata([Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')]),
) )
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
@@ -142,7 +147,7 @@ def test_ASE_Config_QOS() -> None:
def test_ASE_Enable() -> None: def test_ASE_Enable() -> None:
operation = ASE_Enable( operation = ASE_Enable(
ase_id=[1, 2], ase_id=[1, 2],
metadata=[b'foo', b'bar'], metadata=[b'', b''],
) )
basic_check(operation) basic_check(operation)
@@ -151,7 +156,7 @@ def test_ASE_Enable() -> None:
def test_ASE_Update_Metadata() -> None: def test_ASE_Update_Metadata() -> None:
operation = ASE_Update_Metadata( operation = ASE_Update_Metadata(
ase_id=[1, 2], ase_id=[1, 2],
metadata=[b'foo', b'bar'], metadata=[b'', b''],
) )
basic_check(operation) basic_check(operation)

146
tests/bass_test.py Normal file
View File

@@ -0,0 +1,146 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from bumble import hci
from bumble.profiles import bass
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def basic_operation_check(operation: bass.ControlPointOperation) -> None:
serialized = bytes(operation)
parsed = bass.ControlPointOperation.from_bytes(serialized)
assert bytes(parsed) == serialized
# -----------------------------------------------------------------------------
def test_operations() -> None:
op1 = bass.RemoteScanStoppedOperation()
basic_operation_check(op1)
op2 = bass.RemoteScanStartedOperation()
basic_operation_check(op2)
op3 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(),
)
basic_operation_check(op3)
op4 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(
bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')),
bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')),
),
)
basic_operation_check(op4)
op5 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(),
)
basic_operation_check(op5)
op6 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
),
)
basic_operation_check(op6)
op7 = bass.SetBroadcastCodeOperation(
7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf')
)
basic_operation_check(op7)
op8 = bass.RemoveSourceOperation(7)
basic_operation_check(op8)
# -----------------------------------------------------------------------------
def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None:
serialized = bytes(brs)
parsed = bass.BroadcastReceiveState.from_bytes(serialized)
assert parsed is not None
assert bytes(parsed) == serialized
def test_broadcast_receive_state() -> None:
subgroups = [
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
]
brs1 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.DECRYPTING,
b'',
subgroups,
)
basic_broadcast_receive_state_check(brs1)
brs2 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.BAD_CODE,
bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'),
subgroups,
)
basic_broadcast_receive_state_check(brs2)
# -----------------------------------------------------------------------------
async def run():
test_operations()
test_broadcast_receive_state()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -276,34 +276,6 @@ async def test_legacy_advertising():
assert not device.is_advertising assert not device.is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
# Start advertising
await device.start_advertising()
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
await async_barrier()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.parametrize( @pytest.mark.parametrize(
'auto_restart,', 'auto_restart,',
@@ -318,6 +290,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
@@ -367,6 +341,8 @@ async def test_extended_advertising_connection(own_address_type):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
@@ -407,6 +383,8 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
@@ -558,6 +536,16 @@ async def test_cis_setup_failure():
await asyncio.wait_for(cis_create_task, _TIMEOUT) await asyncio.wait_for(cis_create_task, _TIMEOUT)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
devices = TwoDevices()
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
await devices[0].power_on()
assert devices[0].static_address != Address.ANY_RANDOM
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_gatt_services_with_gas(): def test_gatt_services_with_gas():
device = Device(host=Host(None, None)) device = Device(host=Host(None, None))

View File

@@ -879,6 +879,57 @@ async def test_unsubscribe():
mock1.assert_called_once_with(ANY, False, False) mock1.assert_called_once_with(ANY, False, False)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_discover_all():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
descriptor1 = Descriptor('2902', 'READABLE,WRITEABLE')
descriptor2 = Descriptor('AAAA', 'READABLE,WRITEABLE')
characteristic2 = Characteristic(
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
descriptors=[descriptor1, descriptor2],
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic1, characteristic2],
)
service2 = Service('1111', [])
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_all()
assert len(peer.gatt_client.services) == 3
# service 1800 gets added automatically
assert peer.gatt_client.services[0].uuid == UUID('1800')
assert peer.gatt_client.services[1].uuid == service1.uuid
assert peer.gatt_client.services[2].uuid == service2.uuid
s = peer.get_services_by_uuid(service1.uuid)
assert len(s) == 1
assert len(s[0].characteristics) == 2
c = peer.get_characteristics_by_uuid(uuid=characteristic2.uuid, service=s[0])
assert len(c) == 1
assert len(c[0].descriptors) == 2
s = peer.get_services_by_uuid(service2.uuid)
assert len(s) == 1
assert len(s[0].characteristics) == 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mtu_exchange(): async def test_mtu_exchange():
@@ -1146,6 +1197,56 @@ def test_get_attribute_group():
) )
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_characteristics_by_uuid():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'1234',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
characteristic2 = Characteristic(
'5678',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
service1 = Service(
'ABCD',
[characteristic1, characteristic2],
)
service2 = Service(
'FFFF',
[characteristic1],
)
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0
s = peer.get_services_by_uuid(uuid=UUID('ABCD'))
assert len(s) == 1
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=s[0])
assert len(s) == 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

View File

@@ -27,7 +27,6 @@ def test_import():
core, core,
crypto, crypto,
device, device,
gap,
hci, hci,
hfp, hfp,
host, host,
@@ -41,6 +40,22 @@ def test_import():
utils, utils,
) )
from bumble.profiles import (
ascs,
bap,
bass,
battery_service,
cap,
csip,
device_information_service,
gap,
heart_rate_service,
le_audio,
pacs,
pbp,
vcp,
)
assert att assert att
assert bridge assert bridge
assert company_ids assert company_ids
@@ -48,7 +63,6 @@ def test_import():
assert core assert core
assert crypto assert crypto
assert device assert device
assert gap
assert hci assert hci
assert hfp assert hfp
assert host assert host
@@ -61,6 +75,20 @@ def test_import():
assert transport assert transport
assert utils assert utils
assert ascs
assert bap
assert bass
assert battery_service
assert cap
assert csip
assert device_information_service
assert gap
assert heart_rate_service
assert le_audio
assert pacs
assert pbp
assert vcp
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_app_imports(): def test_app_imports():

39
tests/le_audio_test.py Normal file
View File

@@ -0,0 +1,39 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.profiles import le_audio
def test_parse_metadata():
metadata = le_audio.Metadata(
entries=[
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PROGRAM_INFO,
data=b'',
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
data=bytes([0, 0]),
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
data=bytes([1, 2]),
),
]
)
assert le_audio.Metadata.from_bytes(bytes(metadata)) == metadata

132
tests/mcp_test.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import dataclasses
import pytest
import pytest_asyncio
import struct
import logging
from bumble import device
from bumble.profiles import mcp
from tests.test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
TIMEOUT = 0.1
@dataclasses.dataclass
class GmcsContext:
devices: TwoDevices
client: mcp.GenericMediaControlServiceProxy
server: mcp.GenericMediaControlService
# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def gmcs_context():
devices = TwoDevices()
server = mcp.GenericMediaControlService()
devices[0].add_service(server)
await devices.setup_connection()
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
client = await peer.discover_service_and_create_proxy(
mcp.GenericMediaControlServiceProxy
)
await client.subscribe_characteristics()
return GmcsContext(devices=devices, server=server, client=client)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_media_state(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('media_state', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.media_state_characteristic,
value=bytes([mcp.MediaState.PLAYING]),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == mcp.MediaState.PLAYING
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_title(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_title', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_title_characteristic,
value="My Song".encode(),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == "My Song"
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_duration(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_duration', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_duration_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_position(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_position', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_position_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_media_control_point(gmcs_context):
assert (
await asyncio.wait_for(
gmcs_context.client.write_control_point(mcp.MediaControlPointOpcode.PAUSE),
TIMEOUT,
)
) == mcp.MediaControlPointResultCode.SUCCESS

View File

@@ -17,13 +17,17 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import pytest import pytest
from unittest import mock
from bumble import smp from bumble import smp
from bumble import pairing
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1 from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address from bumble.hci import Address
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Device
from typing import Optional
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable=invalid-name # pylint: disable=invalid-name
@@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected) assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'identity_address_type, public_address, random_address, expected_identity_address',
[
(
None,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
None,
Address.ANY,
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.PUBLIC,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.RANDOM,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
identity_address_type: Optional[pairing.PairingConfig.AddressType],
public_address: Address,
random_address: Address,
expected_identity_address: Address,
):
device = Device()
device.public_address = public_address
device.static_address = random_address
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
with mock.patch.object(session, 'send_command') as mock_method:
session.send_identity_address_command()
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_ecc() test_ecc()

View File

@@ -51,3 +51,5 @@ Example:
NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory. NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory.
Make a copy of the built `.whl` file in the `web` directory. Make a copy of the built `.whl` file in the `web` directory.
Tip: During web developement, disable caching. [Chrome](https://stackoverflow.com/a/7000899]) / [Firefiox](https://stackoverflow.com/a/289771)

1
web/favicon.ico Symbolic link
View File

@@ -0,0 +1 @@
../docs/images/favicon.ico

View File

@@ -15,12 +15,21 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import pyee
from bumble.device import Device from bumble.device import Device
from bumble.hci import HCI_Reset_Command from bumble.hci import HCI_Reset_Command
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Scanner: class Scanner(pyee.EventEmitter):
"""
Scanner web app
Emitted events:
update: Emit when new `ScanEntry` are available.
"""
class ScanEntry: class ScanEntry:
def __init__(self, advertisement): def __init__(self, advertisement):
self.address = advertisement.address.to_string(False) self.address = advertisement.address.to_string(False)
@@ -39,13 +48,12 @@ class Scanner:
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
) )
self.scan_entries = {} self.scan_entries = {}
self.listeners = {}
self.device.on('advertisement', self.on_advertisement) self.device.on('advertisement', self.on_advertisement)
async def start(self): async def start(self):
print('### Starting Scanner') print('### Starting Scanner')
self.scan_entries = {} self.scan_entries = {}
self.emit_update() self.emit('update', self.scan_entries)
await self.device.power_on() await self.device.power_on()
await self.device.start_scanning() await self.device.start_scanning()
print('### Scanner started') print('### Scanner started')
@@ -56,16 +64,9 @@ class Scanner:
await self.device.power_off() await self.device.power_off()
print('### Scanner stopped') print('### Scanner stopped')
def emit_update(self):
if listener := self.listeners.get('update'):
listener(list(self.scan_entries.values()))
def on(self, event_name, listener):
self.listeners[event_name] = listener
def on_advertisement(self, advertisement): def on_advertisement(self, advertisement):
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement) self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
self.emit_update() self.emit('update', self.scan_entries)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------