Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
5e4055bb6b fix attribute discovery and display 2024-08-11 09:33:53 -07:00
80 changed files with 2252 additions and 8456 deletions

View File

@@ -40,11 +40,4 @@ jobs:
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
if: always()
run: cat rootcanal.log
- name: Upload Mobly logs
if: always()
uses: actions/upload-artifact@v3
with:
name: mobly-logs
path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -17,11 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import logging
import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
from typing import cast, Dict, Optional, Tuple
import click
import pyee
@@ -33,7 +32,6 @@ import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
@@ -48,16 +46,14 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
# -----------------------------------------------------------------------------
# Scan For Broadcasts
# Discover Broadcasts
# -----------------------------------------------------------------------------
class BroadcastScanner(pyee.EventEmitter):
class BroadcastDiscoverer:
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str
@@ -83,6 +79,22 @@ class BroadcastScanner(pyee.EventEmitter):
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
@@ -127,8 +139,6 @@ class BroadcastScanner(pyee.EventEmitter):
data,
)
self.emit('update')
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
@@ -217,12 +227,13 @@ class BroadcastScanner(pyee.EventEmitter):
)
def on_sync_establishment(self) -> None:
self.emit('sync_establishment')
self.establishment_timeout_task.cancel()
self.emit('change')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('sync_loss')
self.emit('change')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
@@ -257,21 +268,37 @@ class BroadcastScanner(pyee.EventEmitter):
filter_duplicates: bool,
sync_timeout: float,
):
super().__init__()
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement)
async def start(self) -> None:
async def run(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
async def stop(self) -> None:
await self.device.stop_scanning()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if (
@@ -284,6 +311,7 @@ class BroadcastScanner(pyee.EventEmitter):
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
self.refresh()
return
bumble.utils.AsyncRunner.spawn(
@@ -303,318 +331,41 @@ class BroadcastScanner(pyee.EventEmitter):
name,
periodic_advertising_sync,
)
broadcast.on('change', self.refresh)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.emit('new_broadcast', broadcast)
self.status_message = color(
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner:
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
self.scanner.on('update', self.refresh)
self.status_message = ''
async def start(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.scanner.start()
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
)
broadcast.on('change', self.refresh)
broadcast.on('update', self.refresh)
self.refresh()
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
f'-Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.scanner.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
@contextlib.asynccontextmanager
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
async def run_discover_broadcasts(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device_config = bumble.device.DeviceConfiguration(
name=AURACAST_DEFAULT_DEVICE_NAME,
address=AURACAST_DEFAULT_DEVICE_ADDRESS,
keystore='JsonKeyStore',
)
device = bumble.device.Device.from_config_with_hci(
device_config,
device = bumble.device.Device.with_hci(
AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS,
hci_source,
hci_sink,
)
await device.power_on()
yield device
async def find_broadcast_by_name(
device: bumble.device.Device, name: Optional[str]
) -> BroadcastScanner.Broadcast:
result = asyncio.get_running_loop().create_future()
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
if broadcast.basic_audio_announcement and not result.done():
print(color('Broadcast basic audio announcement received', 'green'))
result.set_result(broadcast)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if name is None or broadcast.name == name:
print(color('Broadcast found:', 'green'), broadcast.name)
broadcast.on('change', lambda: on_broadcast_change(broadcast))
return
print(color(f'Skipping broadcast {broadcast.name}'))
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
broadcast = await result
await scanner.stop()
return broadcast
async def run_scan(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
await scanner.start()
await asyncio.get_running_loop().create_future()
async def run_assist(
broadcast_name: Optional[str],
source_id: Optional[int],
command: str,
transport: str,
address: str,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
# Connect to the server
print(f'=== Connecting to {address}...')
connection = await device.connect(address)
peer = bumble.device.Peer(connection)
print(f'=== Connected to {peer}')
print("+++ Encrypting connection...")
await peer.connection.encrypt()
print("+++ Connection encrypted")
# Request a larger MTU
mtu = AURACAST_DEFAULT_ATT_MTU
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
await peer.request_mtu(mtu)
# Get the BASS service
bass = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy
)
# Check that the service was found
if not bass:
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return
# Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
try:
await broadcast_receive_state.subscribe(
lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
)
)
except bumble.core.ProtocolError as error:
print(
color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
'red',
),
error,
)
value = await broadcast_receive_state.read_value()
print(
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
)
if command == 'monitor-state':
await peer.sustain()
return
if command == 'add-source':
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source(
broadcast.sync.advertiser_address,
broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
# Initiate a PA Sync Transfer
await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning.
await bass.remote_scan_stopped()
await peer.sustain()
return
if command == 'modify-source':
if source_id is None:
print(color('!!! modify-source requires --source-id'))
return
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Modify the source
print(
color('Modifying source:', 'blue'),
source_id,
)
await bass.modify_source(
source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
await peer.sustain()
return
if command == 'remove-source':
if source_id is None:
print(color('!!! remove-source requires --source-id'))
return
# Remove the source
print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id)
await peer.sustain()
return
print(color(f'!!! invalid command {command}'))
async def run_pair(transport: str, address: str) -> None:
async with create_device(transport) as device:
# Connect to the server
print(f'=== Connecting to {address}...')
async with device.connect_as_gatt(address) as peer:
print(f'=== Connected to {peer}')
print("+++ Initiating pairing...")
await peer.connection.pair()
print("+++ Paired")
def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
except bumble.core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError
):
message = bumble.profiles.bass.ApplicationError(error.error_code).name
else:
message = str(error)
print(
color('!!! An error occurred while executing the command:', 'red'), message
)
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run()
await hci_source.terminated
# -----------------------------------------------------------------------------
@@ -628,7 +379,7 @@ def auracast(
ctx.ensure_object(dict)
@auracast.command('scan')
@auracast.command('discover-broadcasts')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@@ -636,50 +387,14 @@ def auracast(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
default=5.0,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def scan(ctx, filter_duplicates, sync_timeout, transport):
"""Scan for public broadcasts"""
run_async(run_scan(filter_duplicates, sync_timeout, transport))
@auracast.command('assist')
@click.option(
'--broadcast-name',
metavar='BROADCAST_NAME',
help='Broadcast Name to tune to',
)
@click.option(
'--source-id',
metavar='SOURCE_ID',
type=int,
help='Source ID (for remove-source command)',
)
@click.option(
'--command',
type=click.Choice(
['monitor-state', 'add-source', 'modify-source', 'remove-source']
),
required=True,
)
@click.argument('transport')
@click.argument('address')
@click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address):
"""Scan for broadcasts on behalf of a audio server"""
run_async(run_assist(broadcast_name, source_id, command, transport, address))
@auracast.command('pair')
@click.argument('transport')
@click.argument('address')
@click.pass_context
def pair(ctx, transport, address):
"""Pair with an audio server"""
run_async(run_pair(transport, address))
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
def main():

View File

@@ -168,6 +168,7 @@ class ConsoleApp:
'remote-services': None,
'local-values': None,
'remote-values': None,
'remote-attributes': None,
},
'filter': {
'address': None,
@@ -216,6 +217,7 @@ class ConsoleApp:
)
self.local_values_text = FormattedTextControl()
self.remote_values_text = FormattedTextControl()
self.remote_attributes_text = FormattedTextControl()
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = []
@@ -242,6 +244,12 @@ class ConsoleApp:
Frame(Window(self.remote_values_text), title='Remote Values'),
filter=Condition(lambda: self.top_tab == 'remote-values'),
),
ConditionalContainer(
Frame(
Window(self.remote_attributes_text), title='Remote Attributes'
),
filter=Condition(lambda: self.top_tab == 'remote-attributes'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log'),
@@ -504,6 +512,8 @@ class ConsoleApp:
await self.connected_peer.discover_all()
self.append_to_output('Service Discovery done!')
self.show_remote_services(self.connected_peer.services)
async def discover_attributes(self):
if not self.connected_peer:
self.show_error('not connected')
@@ -514,7 +524,7 @@ class ConsoleApp:
attributes = await self.connected_peer.discover_attributes()
self.append_to_output(f'discovered {len(attributes)} attributes...')
self.show_attributes(attributes)
await self.show_remote_attributes(attributes)
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
if not self.connected_peer:
@@ -659,7 +669,6 @@ class ConsoleApp:
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except bumble.core.TimeoutError:
self.show_error('connection timed out')
@@ -730,19 +739,20 @@ class ConsoleApp:
'remote-services',
'local-values',
'remote-values',
'remote-attributes',
}:
self.top_tab = params[0]
self.ui.invalidate()
while self.top_tab == 'local-values':
await self.do_show_local_values()
await self.show_local_values()
await asyncio.sleep(1)
while self.top_tab == 'remote-values':
await self.do_show_remote_values()
await self.show_remote_values()
await asyncio.sleep(1)
async def do_show_local_values(self):
async def show_local_values(self):
prettytable = PrettyTable()
field_names = ["Service", "Characteristic", "Descriptor"]
@@ -797,7 +807,7 @@ class ConsoleApp:
self.local_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def do_show_remote_values(self):
async def show_remote_values(self):
prettytable = PrettyTable(
field_names=[
"Connection",
@@ -831,6 +841,23 @@ class ConsoleApp:
self.remote_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def show_remote_attributes(self, attributes):
lines = []
for attribute in attributes:
lines.append(('ansimagenta', str(attribute) + "\n"))
try:
value = await attribute.read_value()
lines.append(('ansicyan', value.hex() + "\n"))
except bumble.core.ProtocolError as error:
lines.append(("ansired", f"!!! Protocol Error ({error})\n"))
except bumble.core.TimeoutError:
lines.append(("ansired", "!!! Timeout\n"))
except Exception as error:
lines.append(("ansired", f"!!! Error ({error})\n"))
self.remote_attributes_text.text = lines
self.ui.invalidate()
async def do_get_phy(self, _):
if not self.connected_peer:
self.show_error('not connected')

View File

@@ -27,7 +27,6 @@ from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
CodecID,
LeFeature,
HCI_SUCCESS,
HCI_VERSION_NAMES,
@@ -51,8 +50,6 @@ from bumble.hci import (
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.host import Host
@@ -171,60 +168,6 @@ async def get_acl_flow_control_info(host: Host) -> None:
)
# -----------------------------------------------------------------------------
async def get_codecs_info(host: Host) -> None:
print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
)
print(color('Codecs:', 'yellow'))
for codec_id, transport in zip(
response.return_parameters.standard_codec_ids,
response.return_parameters.standard_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
for codec_id, transport in zip(
response.return_parameters.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
if not response.return_parameters.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
)
print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids:
codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
print('<<< connecting to HCI...')
@@ -277,9 +220,6 @@ async def async_main(latency_probes, transport):
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Get codec info
await get_codecs_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))

View File

@@ -1,230 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from typing import Callable, Iterable, Optional
import click
from bumble.core import ProtocolError
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def try_show(function: Callable, *args, **kwargs) -> None:
try:
await function(*args, **kwargs)
except ProtocolError as error:
print(color('ERROR:', 'red'), error)
# -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
# -----------------------------------------------------------------------------
async def show_gap_information(
gap_service: GenericAccessServiceProxy,
):
print(color('### Generic Access Profile', 'yellow'))
if gap_service.device_name:
print(
color(' Device Name:', 'green'),
await gap_service.device_name.read_value(),
)
if gap_service.appearance:
print(
color(' Appearance: ', 'green'),
await gap_service.appearance.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_information(
device_information_service: DeviceInformationServiceProxy,
):
print(color('### Device Information', 'yellow'))
if device_information_service.manufacturer_name:
print(
color(' Manufacturer Name:', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number:
print(
color(' Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number:
print(
color(' Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.firmware_revision:
print(
color(' Firmware Revision:', 'green'),
await device_information_service.firmware_revision.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_battery_level(
battery_service: BatteryServiceProxy,
):
print(color('### Battery Information', 'yellow'))
if battery_service.battery_level:
print(
color(' Battery Level:', 'green'),
await battery_service.battery_level.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_tmas(
tmas: TelephonyAndMediaAudioServiceProxy,
):
print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role:
print(
color(' Role:', 'green'),
await tmas.role.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
await peer.discover_services()
for service in peer.services:
await service.discover_characteristics()
print(color('=== Services ===', 'yellow'))
show_services(peer.services)
print()
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
await try_show(show_gap_information, gap_service)
if device_information_service := peer.create_service_proxy(
DeviceInformationServiceProxy
):
await try_show(show_device_information, device_information_service)
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
await try_show(show_battery_level, battery_service)
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas)
if done is not None:
done.set_result(None)
except asyncio.CancelledError:
print(color('!!! Operation canceled', 'red'))
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await show_device_info(Peer(connection), None)
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on(
'connection',
lambda connection: asyncio.create_task(
show_device_info(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
await done
# -----------------------------------------------------------------------------
@click.command()
@click.option('--device-config', help='Device configuration', type=click.Path())
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
@click.argument('transport')
@click.argument('address-or-name', required=False)
def main(device_config, encrypt, transport, address_or_name):
"""
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -75,15 +75,11 @@ async def async_main(device_config, encrypt, transport, address_or_name):
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await dump_gatt_db(Peer(connection), None)
else:

View File

@@ -33,6 +33,7 @@ import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
@@ -42,7 +43,7 @@ from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
@@ -56,8 +57,8 @@ logger = logging.getLogger(__name__)
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
@@ -78,8 +79,8 @@ def _sink_pac_record() -> pacs.PacRecord:
)
def _source_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
@@ -446,7 +447,7 @@ class Speaker:
)
self.device.add_service(
pacs.PublishedAudioCapabilitiesService(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
@@ -460,10 +461,10 @@ class Speaker:
)
)
ascs_service = ascs.AudioStreamControlService(
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs_service)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
@@ -478,13 +479,13 @@ class Speaker:
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
@@ -494,12 +495,12 @@ class Speaker:
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
if ase.state == ascs.AseStateMachine.State.STREAMING:
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == ascs.AudioRole.SOURCE:
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
@@ -515,10 +516,10 @@ class Speaker:
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == ascs.AudioRole.SOURCE:
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
@@ -531,7 +532,7 @@ class Speaker:
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs_service.ase_state_machines.values():
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()

View File

@@ -46,12 +46,6 @@ from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
@@ -241,10 +235,8 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', lambda keys: on_pairing(connection, keys))
connection.on(
'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
connection.on(
@@ -278,24 +270,19 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing(connection, keys):
def on_pairing(address, keys):
print(color('***-----------------------------------', 'cyan'))
print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
await asyncio.sleep(POST_PAIRING_DELAY)
await connection.disconnect()
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
Waiter.instance.terminate()
@@ -306,7 +293,6 @@ async def pair(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -396,18 +382,11 @@ async def pair(
oob_contexts = None
# Set up a pairing config factory
if identity_address == 'public':
identity_address_type = PairingConfig.AddressType.PUBLIC
elif identity_address == 'random':
identity_address_type = PairingConfig.AddressType.RANDOM
else:
identity_address_type = None
device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
identity_address_type=identity_address_type,
delegate=Delegate(mode, connection, io, prompt),
)
@@ -478,10 +457,6 @@ class LogHandler(logging.Handler):
help='Enable CTKD',
show_default=True,
)
@click.option(
'--identity-address',
type=click.Choice(['random', 'public']),
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
@@ -518,7 +493,6 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -544,7 +518,6 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,

View File

@@ -1,608 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import asyncio.subprocess
import os
import logging
from typing import Optional, Union
import click
from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE,
AacFrame,
AacParser,
AacPacketSource,
AacMediaCodecInformation,
SbcFrame,
SbcParser,
SbcPacketSource,
SbcMediaCodecInformation,
OpusPacket,
OpusParser,
OpusPacketSource,
OpusMediaCodecInformation,
)
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.avdtp import (
find_avdtp_service_with_connection,
AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
MediaCodecCapabilities,
MediaPacketPump,
Protocol as AvdtpProtocol,
)
from bumble.colors import color
from bumble.core import (
AdvertisingData,
ConnectionError as BumbleConnectionError,
DeviceClass,
BT_BR_EDR_TRANSPORT,
)
from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant
from bumble.pairing import PairingConfig
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def a2dp_source_sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
# -----------------------------------------------------------------------------
async def sbc_codec_capabilities(read_function) -> MediaCodecCapabilities:
sbc_parser = SbcParser(read_function)
sbc_frame: SbcFrame
async for sbc_frame in sbc_parser.frames:
# We only need the first frame
print(color(f"SBC format: {sbc_frame}", "cyan"))
break
channel_mode = [
SbcMediaCodecInformation.ChannelMode.MONO,
SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL,
SbcMediaCodecInformation.ChannelMode.STEREO,
SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
][sbc_frame.channel_mode]
block_length = {
4: SbcMediaCodecInformation.BlockLength.BL_4,
8: SbcMediaCodecInformation.BlockLength.BL_8,
12: SbcMediaCodecInformation.BlockLength.BL_12,
16: SbcMediaCodecInformation.BlockLength.BL_16,
}[sbc_frame.block_count]
subbands = {
4: SbcMediaCodecInformation.Subbands.S_4,
8: SbcMediaCodecInformation.Subbands.S_8,
}[sbc_frame.subband_count]
allocation_method = [
SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
SbcMediaCodecInformation.AllocationMethod.SNR,
][sbc_frame.allocation_method]
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.from_int(
sbc_frame.sampling_frequency
),
channel_mode=channel_mode,
block_length=block_length,
subbands=subbands,
allocation_method=allocation_method,
minimum_bitpool_value=2,
maximum_bitpool_value=40,
),
)
# -----------------------------------------------------------------------------
async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities:
aac_parser = AacParser(read_function)
aac_frame: AacFrame
async for aac_frame in aac_parser.frames:
# We only need the first frame
print(color(f"AAC format: {aac_frame}", "cyan"))
break
sampling_frequency = AacMediaCodecInformation.SamplingFrequency.from_int(
aac_frame.sampling_frequency
)
channels = (
AacMediaCodecInformation.Channels.MONO
if aac_frame.channel_configuration == 1
else AacMediaCodecInformation.Channels.STEREO
)
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=sampling_frequency,
channels=channels,
vbr=1,
bitrate=128000,
),
)
# -----------------------------------------------------------------------------
async def opus_codec_capabilities(read_function) -> MediaCodecCapabilities:
opus_parser = OpusParser(read_function)
opus_packet: OpusPacket
async for opus_packet in opus_parser.packets:
# We only need the first packet
print(color(f"Opus format: {opus_packet}", "cyan"))
break
if opus_packet.channel_mode == OpusPacket.ChannelMode.MONO:
channel_mode = OpusMediaCodecInformation.ChannelMode.MONO
elif opus_packet.channel_mode == OpusPacket.ChannelMode.STEREO:
channel_mode = OpusMediaCodecInformation.ChannelMode.STEREO
else:
channel_mode = OpusMediaCodecInformation.ChannelMode.DUAL_MONO
if opus_packet.duration == 10:
frame_size = OpusMediaCodecInformation.FrameSize.FS_10MS
else:
frame_size = OpusMediaCodecInformation.FrameSize.FS_20MS
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
media_codec_information=OpusMediaCodecInformation(
channel_mode=channel_mode,
sampling_frequency=OpusMediaCodecInformation.SamplingFrequency.SF_48000,
frame_size=frame_size,
),
)
# -----------------------------------------------------------------------------
class Player:
def __init__(
self,
transport: str,
device_config: Optional[str],
authenticate: bool,
encrypt: bool,
) -> None:
self.transport = transport
self.device_config = device_config
self.authenticate = authenticate
self.encrypt = encrypt
self.avrcp_protocol: Optional[AvrcpProtocol] = None
self.done: Optional[asyncio.Event]
async def run(self, workload) -> None:
self.done = asyncio.Event()
try:
await self._run(workload)
except Exception as error:
print(color(f"!!! ERROR: {error}", "red"))
async def _run(self, workload) -> None:
async with await open_transport(self.transport) as (hci_source, hci_sink):
# Create a device
device_config = DeviceConfiguration()
if self.device_config:
device_config.load_from_file(self.device_config)
else:
device_config.name = "Bumble Player"
device_config.class_of_device = DeviceClass.pack_class_of_device(
DeviceClass.AUDIO_SERVICE_CLASS,
DeviceClass.AUDIO_VIDEO_MAJOR_DEVICE_CLASS,
DeviceClass.AUDIO_VIDEO_UNCATEGORIZED_MINOR_DEVICE_CLASS,
)
device_config.keystore = "JsonKeyStore"
device_config.classic_enabled = True
device_config.le_enabled = False
device_config.le_simultaneous_enabled = False
device_config.classic_sc_enabled = False
device_config.classic_smp_enabled = False
device = Device.from_config_with_hci(device_config, hci_source, hci_sink)
# Setup the SDP records to expose the SRC service
device.sdp_service_records = a2dp_source_sdp_records()
# Setup AVRCP
self.avrcp_protocol = AvrcpProtocol()
self.avrcp_protocol.listen(device)
# Don't require MITM when pairing.
device.pairing_config_factory = lambda connection: PairingConfig(mitm=False)
# Start the controller
await device.power_on()
# Print some of the config/properties
print(
"Player Bluetooth Address:",
color(
device.public_address.to_string(with_type_qualifier=False),
"yellow",
),
)
# Listen for connections
device.on("connection", self.on_bluetooth_connection)
# Run the workload
try:
await workload(device)
except BumbleConnectionError as error:
if error.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
print(color("Connection already established", "blue"))
else:
print(color(f"Failed to connect: {error}", "red"))
# Wait until it is time to exit
assert self.done is not None
await asyncio.wait(
[hci_source.terminated, asyncio.ensure_future(self.done.wait())],
return_when=asyncio.FIRST_COMPLETED,
)
def on_bluetooth_connection(self, connection: Connection) -> None:
print(color(f"--- Connected: {connection}", "cyan"))
connection.on("disconnection", self.on_bluetooth_disconnection)
def on_bluetooth_disconnection(self, reason) -> None:
print(color(f"--- Disconnected: {HCI_Constant.error_name(reason)}", "cyan"))
self.set_done()
async def connect(self, device: Device, address: str) -> Connection:
print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
# Request authentication
if self.authenticate:
print(color("*** Authenticating...", "blue"))
await connection.authenticate()
print(color("*** Authenticated", "blue"))
# Enable encryption
if self.encrypt:
print(color("*** Enabling encryption...", "blue"))
await connection.encrypt()
print(color("*** Encryption on", "blue"))
return connection
async def create_avdtp_protocol(self, connection: Connection) -> AvdtpProtocol:
# Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(connection)
if not avdtp_version:
raise RuntimeError("no A2DP service found")
print(color(f"AVDTP Version: {avdtp_version}"))
# Create a client to interact with the remote device
return await AvdtpProtocol.connect(connection, avdtp_version)
async def stream_packets(
self,
protocol: AvdtpProtocol,
codec_type: int,
vendor_id: int,
codec_id: int,
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
codec_capabilities: MediaCodecCapabilities,
):
# Discover all endpoints on the remote device
endpoints = await protocol.discover_remote_endpoints()
for endpoint in endpoints:
print('@@@', endpoint)
# Select a sink
sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, codec_type, vendor_id, codec_id
)
if sink is None:
print(color('!!! no compatible sink found', 'red'))
return
print(f'### Selected sink: {sink.seid}')
# Check if the sink supports delay reporting
delay_reporting = False
for capability in sink.capabilities:
if capability.service_category == AVDTP_DELAY_REPORTING_SERVICE_CATEGORY:
delay_reporting = True
break
def on_delay_report(delay: int):
print(color(f"*** DELAY REPORT: {delay}", "blue"))
# Adjust the codec capabilities for certain codecs
for capability in sink.capabilities:
if isinstance(capability, MediaCodecCapabilities):
if isinstance(
codec_capabilities.media_codec_information, SbcMediaCodecInformation
) and isinstance(
capability.media_codec_information, SbcMediaCodecInformation
):
codec_capabilities.media_codec_information.minimum_bitpool_value = (
capability.media_codec_information.minimum_bitpool_value
)
codec_capabilities.media_codec_information.maximum_bitpool_value = (
capability.media_codec_information.maximum_bitpool_value
)
print(color("Source media codec:", "green"), codec_capabilities)
# Stream the packets
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(codec_capabilities, packet_pump, delay_reporting)
source.on("delay_report", on_delay_report)
stream = await protocol.create_stream(source, sink)
await stream.start()
await packet_pump.wait_for_completion()
async def discover(self, device: Device) -> None:
@device.listens_to("inquiry_result")
def on_inquiry_result(
address: Address, class_of_device: int, data: AdvertisingData, rssi: int
) -> None:
(
service_classes,
major_device_class,
minor_device_class,
) = DeviceClass.split_class_of_device(class_of_device)
separator = "\n "
print(f">>> {color(address.to_string(False), 'yellow')}:")
print(f" Device Class (raw): {class_of_device:06X}")
major_class_name = DeviceClass.major_device_class_name(major_device_class)
print(" Device Major Class: " f"{major_class_name}")
minor_class_name = DeviceClass.minor_device_class_name(
major_device_class, minor_device_class
)
print(" Device Minor Class: " f"{minor_class_name}")
print(
" Device Services: "
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
)
print(f" RSSI: {rssi}")
if data.ad_structures:
print(f" {data.to_string(separator)}")
await device.start_discovery()
async def pair(self, device: Device, address: str) -> None:
print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
print(color("Pairing...", "magenta"))
await connection.authenticate()
print(color("Pairing completed", "magenta"))
self.set_done()
async def inquire(self, device: Device, address: str) -> None:
connection = await self.connect(device, address)
avdtp_protocol = await self.create_avdtp_protocol(connection)
# Discover the remote endpoints
endpoints = await avdtp_protocol.discover_remote_endpoints()
print(f'@@@ Found {len(list(endpoints))} endpoints')
for endpoint in endpoints:
print('@@@', endpoint)
self.set_done()
async def play(
self,
device: Device,
address: Optional[str],
audio_format: str,
audio_file: str,
) -> None:
if audio_format == "auto":
if audio_file.endswith(".sbc"):
audio_format = "sbc"
elif audio_file.endswith(".aac") or audio_file.endswith(".adts"):
audio_format = "aac"
elif audio_file.endswith(".ogg"):
audio_format = "opus"
else:
raise ValueError("Unable to determine audio format from file extension")
device.on(
"connection",
lambda connection: AsyncRunner.spawn(on_connection(connection)),
)
async def on_connection(connection: Connection):
avdtp_protocol = await self.create_avdtp_protocol(connection)
with open(audio_file, 'rb') as input_file:
# NOTE: this should be using asyncio file reading, but blocking reads
# are good enough for this command line app.
async def read_audio_data(byte_count):
return input_file.read(byte_count)
# Obtain the codec capabilities from the stream
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
vendor_id = 0
codec_id = 0
if audio_format == "sbc":
codec_type = A2DP_SBC_CODEC_TYPE
codec_capabilities = await sbc_codec_capabilities(read_audio_data)
packet_source = SbcPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
elif audio_format == "aac":
codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
codec_capabilities = await aac_codec_capabilities(read_audio_data)
packet_source = AacPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
else:
codec_type = A2DP_NON_A2DP_CODEC_TYPE
vendor_id = OpusMediaCodecInformation.VENDOR_ID
codec_id = OpusMediaCodecInformation.CODEC_ID
codec_capabilities = await opus_codec_capabilities(read_audio_data)
packet_source = OpusPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
# Rewind to the start
input_file.seek(0)
try:
await self.stream_packets(
avdtp_protocol,
codec_type,
vendor_id,
codec_id,
packet_source,
codec_capabilities,
)
except Exception as error:
print(color(f"!!! Error while streaming: {error}", "red"))
self.set_done()
if address:
await self.connect(device, address)
else:
print(color("Waiting for an incoming connection...", "magenta"))
def set_done(self) -> None:
if self.done:
self.done.set()
# -----------------------------------------------------------------------------
def create_player(context) -> Player:
return Player(
transport=context.obj["hci_transport"],
device_config=context.obj["device_config"],
authenticate=context.obj["authenticate"],
encrypt=context.obj["encrypt"],
)
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option("--hci-transport", metavar="TRANSPORT", required=True)
@click.option("--device-config", metavar="FILENAME", help="Device configuration file")
@click.option(
"--authenticate",
is_flag=True,
help="Request authentication when connecting",
default=False,
)
@click.option(
"--encrypt", is_flag=True, help="Request encryption when connecting", default=True
)
def player_cli(ctx, hci_transport, device_config, authenticate, encrypt):
ctx.ensure_object(dict)
ctx.obj["hci_transport"] = hci_transport
ctx.obj["device_config"] = device_config
ctx.obj["authenticate"] = authenticate
ctx.obj["encrypt"] = encrypt
@player_cli.command("discover")
@click.pass_context
def discover(context):
"""Discover speakers or headphones"""
player = create_player(context)
asyncio.run(player.run(player.discover))
@player_cli.command("inquire")
@click.pass_context
@click.argument(
"address",
metavar="ADDRESS",
)
def inquire(context, address):
"""Connect to a speaker or headphone and inquire about their capabilities"""
player = create_player(context)
asyncio.run(player.run(lambda device: player.inquire(device, address)))
@player_cli.command("pair")
@click.pass_context
@click.argument(
"address",
metavar="ADDRESS",
)
def pair(context, address):
"""Pair with a speaker or headphone"""
player = create_player(context)
asyncio.run(player.run(lambda device: player.pair(device, address)))
@player_cli.command("play")
@click.pass_context
@click.option(
"--connect",
"address",
metavar="ADDRESS",
help="Address or name to connect to",
)
@click.option(
"-f",
"--audio-format",
type=click.Choice(["auto", "sbc", "aac", "opus"]),
help="Audio file format (use 'auto' to infer the format from the file extension)",
default="auto",
)
@click.argument("audio_file")
def play(context, address, audio_format, audio_file):
"""Play and audio file"""
player = create_player(context)
asyncio.run(
player.run(
lambda device: player.play(device, address, audio_format, audio_file)
)
)
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
player_cli()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

View File

@@ -44,18 +44,25 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities,
MediaPacket,
Protocol,
)
from bumble.a2dp import (
MPEG_2_AAC_LC_OBJECT_TYPE,
make_audio_sink_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_SNR_ALLOCATION_METHOD,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation,
AacMediaCodecInformation,
)
from bumble.utils import AsyncRunner
from bumble.codecs import AacAudioRtpPacket
from bumble.rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -86,7 +93,7 @@ class AudioExtractor:
# -----------------------------------------------------------------------------
class AacAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
return AacAudioRtpPacket.from_bytes(packet.payload).to_adts()
return AacAudioRtpPacket(packet.payload).to_adts()
# -----------------------------------------------------------------------------
@@ -444,12 +451,10 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=AacMediaCodecInformation.SamplingFrequency.SF_48000
| AacMediaCodecInformation.SamplingFrequency.SF_44100,
channels=AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO,
media_codec_information=AacMediaCodecInformation.from_lists(
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
sampling_frequencies=[48000, 44100],
channels=[1, 2],
vbr=1,
bitrate=256000,
),
@@ -459,23 +464,20 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_4
| SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| SbcMediaCodecInformation.AllocationMethod.SNR,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),

View File

@@ -17,16 +17,12 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import AsyncGenerator
import dataclasses
import enum
import logging
import struct
from typing import Awaitable, Callable
from typing_extensions import ClassVar, Self
import logging
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable
from .codecs import AacAudioRtpPacket
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
DataElement,
@@ -46,7 +42,6 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number,
)
from .rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -108,8 +103,6 @@ SBC_ALLOCATION_METHOD_NAMES = {
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
8000,
11025,
@@ -137,9 +130,6 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
# fmt: on
@@ -267,61 +257,38 @@ class SbcMediaCodecInformation:
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
sampling_frequency: SamplingFrequency
channel_mode: ChannelMode
block_length: BlockLength
subbands: Subbands
allocation_method: AllocationMethod
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int
class SamplingFrequency(enum.IntFlag):
SF_16000 = 1 << 3
SF_32000 = 1 << 2
SF_44100 = 1 << 1
SF_48000 = 1 << 0
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
}
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
16000,
32000,
44100,
48000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class ChannelMode(enum.IntFlag):
MONO = 1 << 3
DUAL_CHANNEL = 1 << 2
STEREO = 1 << 1
JOINT_STEREO = 1 << 0
class BlockLength(enum.IntFlag):
BL_4 = 1 << 3
BL_8 = 1 << 2
BL_12 = 1 << 1
BL_16 = 1 << 0
class Subbands(enum.IntFlag):
S_4 = 1 << 1
S_8 = 1 << 0
class AllocationMethod(enum.IntFlag):
SNR = 1 << 1
LOUDNESS = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> Self:
sampling_frequency = cls.SamplingFrequency((data[0] >> 4) & 0x0F)
channel_mode = cls.ChannelMode((data[0] >> 0) & 0x0F)
block_length = cls.BlockLength((data[1] >> 4) & 0x0F)
subbands = cls.Subbands((data[1] >> 2) & 0x03)
allocation_method = cls.AllocationMethod((data[1] >> 0) & 0x03)
@staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return cls(
return SbcMediaCodecInformation(
sampling_frequency,
channel_mode,
block_length,
@@ -331,6 +298,52 @@ class SbcMediaCodecInformation:
maximum_bitpool_value,
)
@classmethod
def from_discrete_values(
cls,
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
@classmethod
def from_lists(
cls,
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self) -> bytes:
return bytes(
[
@@ -343,6 +356,23 @@ class SbcMediaCodecInformation:
]
)
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
@@ -351,66 +381,83 @@ class AacMediaCodecInformation:
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
object_type: ObjectType
sampling_frequency: SamplingFrequency
channels: Channels
object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int
class ObjectType(enum.IntFlag):
MPEG_2_AAC_LC = 1 << 7
MPEG_4_AAC_LC = 1 << 6
MPEG_4_AAC_LTP = 1 << 5
MPEG_4_AAC_SCALABLE = 1 << 4
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
11025: 1 << 10,
12000: 1 << 9,
16000: 1 << 8,
22050: 1 << 7,
24000: 1 << 6,
32000: 1 << 5,
44100: 1 << 4,
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1,
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
class SamplingFrequency(enum.IntFlag):
SF_8000 = 1 << 11
SF_11025 = 1 << 10
SF_12000 = 1 << 9
SF_16000 = 1 << 8
SF_22050 = 1 << 7
SF_24000 = 1 << 6
SF_32000 = 1 << 5
SF_44100 = 1 << 4
SF_48000 = 1 << 3
SF_64000 = 1 << 2
SF_88200 = 1 << 1
SF_96000 = 1 << 0
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
8000,
11025,
12000,
16000,
22050,
24000,
32000,
44100,
48000,
64000,
88200,
96000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class Channels(enum.IntFlag):
MONO = 1 << 1
STEREO = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> AacMediaCodecInformation:
object_type = cls.ObjectType(data[0])
sampling_frequency = cls.SamplingFrequency(
(data[1] << 4) | ((data[2] >> 4) & 0x0F)
)
channels = cls.Channels((data[2] >> 2) & 0x03)
@staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
object_type, sampling_frequency, channels, vbr, bitrate
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
@classmethod
def from_discrete_values(
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
@classmethod
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
def __bytes__(self) -> bytes:
@@ -425,6 +472,30 @@ class AacMediaCodecInformation:
]
)
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2]
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' ')',
]
)
@dataclasses.dataclass
# -----------------------------------------------------------------------------
@@ -443,7 +514,7 @@ class VendorSpecificMediaCodecInformation:
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self) -> str:
# pylint: disable=line-too-long
@@ -457,69 +528,13 @@ class VendorSpecificMediaCodecInformation:
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
vendor_id: int = dataclasses.field(init=False, repr=False)
codec_id: int = dataclasses.field(init=False, repr=False)
value: bytes = dataclasses.field(init=False, repr=False)
channel_mode: ChannelMode
frame_size: FrameSize
sampling_frequency: SamplingFrequency
class ChannelMode(enum.IntFlag):
MONO = 1 << 0
STEREO = 1 << 1
DUAL_MONO = 1 << 2
class FrameSize(enum.IntFlag):
FS_10MS = 1 << 0
FS_20MS = 1 << 1
class SamplingFrequency(enum.IntFlag):
SF_48000 = 1 << 0
VENDOR_ID: ClassVar[int] = 0x000000E0
CODEC_ID: ClassVar[int] = 0x0001
def __post_init__(self) -> None:
self.vendor_id = self.VENDOR_ID
self.codec_id = self.CODEC_ID
self.value = bytes(
[
self.channel_mode
| (self.frame_size << 3)
| (self.sampling_frequency << 7)
]
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
"""Create a new instance from the `value` part of the data, not including
the vendor id and codec id"""
channel_mode = cls.ChannelMode(data[0] & 0x07)
frame_size = cls.FrameSize((data[0] >> 3) & 0x03)
sampling_frequency = cls.SamplingFrequency((data[0] >> 7) & 0x01)
return cls(
channel_mode,
frame_size,
sampling_frequency,
)
def __str__(self) -> str:
return repr(self)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
sampling_frequency: int
block_count: int
channel_mode: int
allocation_method: int
subband_count: int
bitpool: int
payload: bytes
@property
@@ -538,10 +553,8 @@ class SbcFrame:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'am={self.allocation_method},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'bp={self.bitpool},'
f'size={len(self.payload)})'
)
@@ -570,7 +583,6 @@ class SbcParser:
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
allocation_method = (header[1] >> 1) & 1
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
@@ -590,13 +602,7 @@ class SbcParser:
# Emit the next frame
yield SbcFrame(
sampling_frequency,
blocks,
channel_mode,
allocation_method,
subbands,
bitpool,
payload,
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames()
@@ -604,15 +610,21 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
sample_count = 0
timestamp = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
@@ -620,29 +632,29 @@ class SbcPacketSource:
# NOTE: this doesn't support frame fragments
sbc_parser = SbcParser(self.read)
async for frame in sbc_parser.frames:
print(frame)
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
or len(frames) == 16
):
# Need to flush what has been accumulated so far
logger.debug(f"yielding {len(frames)} frames")
# Emit a packet
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
sbc_payload = bytes([len(frames)]) + b''.join(
[frame.payload for frame in frames]
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp_seconds
packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += sum((frame.sample_count for frame in frames))
timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
frames = [frame]
frames_size = len(frame.payload)
else:
@@ -651,315 +663,3 @@ class SbcPacketSource:
frames_size += len(frame.payload)
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacFrame:
class Profile(enum.IntEnum):
MAIN = 0
LC = 1
SSR = 2
LTP = 3
profile: Profile
sampling_frequency: int
channel_configuration: int
payload: bytes
@property
def sample_count(self) -> int:
return 1024
@property
def duration(self) -> float:
return self.sample_count / self.sampling_frequency
def __str__(self) -> str:
return (
f'AAC(sf={self.sampling_frequency},'
f'ch={self.channel_configuration},'
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
ADTS_AAC_SAMPLING_FREQUENCIES = [
96000,
88200,
64000,
48000,
44100,
32000,
24000,
22050,
16000,
12000,
11025,
8000,
7350,
0,
0,
0,
]
# -----------------------------------------------------------------------------
class AacParser:
"""Parser for AAC frames in an ADTS stream"""
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def frames(self) -> AsyncGenerator[AacFrame, None]:
async def generate_frames() -> AsyncGenerator[AacFrame, None]:
while True:
header = await self.read(7)
if not header:
return
sync_word = (header[0] << 4) | (header[1] >> 4)
if sync_word != 0b111111111111:
raise ValueError(f"invalid sync word ({sync_word:06x})")
layer = (header[1] >> 1) & 0b11
profile = AacFrame.Profile((header[2] >> 6) & 0b11)
sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[
(header[2] >> 2) & 0b1111
]
channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6)
frame_length = (
((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5)
)
if layer != 0:
raise ValueError("layer must be 0")
payload = await self.read(frame_length - 7)
if payload:
yield AacFrame(
profile, sampling_frequency, channel_configuration, payload
)
return generate_frames()
# -----------------------------------------------------------------------------
class AacPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
sample_count = 0
aac_parser = AacParser(self.read)
async for frame in aac_parser.frames:
logger.debug("yielding one AAC frame")
# Emit a packet
aac_payload = bytes(
AacAudioRtpPacket.for_simple_aac(
frame.sampling_frequency,
frame.channel_configuration,
frame.payload,
)
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
)
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += frame.sample_count
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusPacket:
class ChannelMode(enum.IntEnum):
MONO = 0
STEREO = 1
DUAL_MONO = 2
channel_mode: ChannelMode
duration: int # Duration in ms.
sampling_frequency: int
payload: bytes
def __str__(self) -> str:
return (
f'Opus(ch={self.channel_mode.name}, '
f'd={self.duration}ms, '
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class OpusParser:
"""
Parser for Opus packets in an Ogg stream
See RFC 3533
NOTE: this parser only supports bitstreams with a single logical stream.
"""
CAPTURE_PATTERN = b'OggS'
class HeaderType(enum.IntFlag):
CONTINUED = 0x01
FIRST = 0x02
LAST = 0x04
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def packets(self) -> AsyncGenerator[OpusPacket, None]:
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
packet = b''
packet_count = 0
expected_bitstream_serial_number = None
expected_page_sequence_number = 0
channel_mode = OpusPacket.ChannelMode.STEREO
while True:
# Parse the page header
header = await self.read(27)
if len(header) != 27:
logger.debug("end of stream")
break
capture_pattern = header[:4]
if capture_pattern != self.CAPTURE_PATTERN:
print(capture_pattern.hex())
raise ValueError("invalid capture pattern at start of page")
version = header[4]
if version != 0:
raise ValueError(f"version {version} not supported")
header_type = self.HeaderType(header[5])
(
granule_position,
bitstream_serial_number,
page_sequence_number,
crc_checksum,
page_segments,
) = struct.unpack_from("<QIIIB", header, 6)
segment_table = await self.read(page_segments)
if header_type & self.HeaderType.FIRST:
if expected_bitstream_serial_number is None:
# We will only accept pages for the first encountered stream
logger.debug("BOS")
expected_bitstream_serial_number = bitstream_serial_number
expected_page_sequence_number = page_sequence_number
if (
expected_bitstream_serial_number is None
or expected_bitstream_serial_number != bitstream_serial_number
):
logger.debug("skipping page (not the first logical bitstream)")
for lacing_value in segment_table:
if lacing_value:
await self.read(lacing_value)
continue
if expected_page_sequence_number != page_sequence_number:
raise ValueError(
f"expected page sequence number {expected_page_sequence_number}"
f" but got {page_sequence_number}"
)
expected_page_sequence_number = page_sequence_number + 1
# Assemble the page
if not header_type & self.HeaderType.CONTINUED:
packet = b''
for lacing_value in segment_table:
if lacing_value:
packet += await self.read(lacing_value)
if lacing_value < 255:
# End of packet
packet_count += 1
if packet_count == 1:
# The first packet contains the identification header
logger.debug("first packet (header)")
if packet[:8] != b"OpusHead":
raise ValueError("first packet is not OpusHead")
packet_count = (
OpusPacket.ChannelMode.MONO
if packet[9] == 1
else OpusPacket.ChannelMode.STEREO
)
elif packet_count == 2:
# The second packet contains the comment header
logger.debug("second packet (tags)")
if packet[:8] != b"OpusTags":
logger.warning("second packet is not OpusTags")
else:
yield OpusPacket(channel_mode, 20, 48000, packet)
packet = b''
if header_type & self.HeaderType.LAST:
logger.debug("EOS")
return generate_frames()
# -----------------------------------------------------------------------------
class OpusPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
elapsed_ms = 0
opus_parser = OpusParser(self.read)
async for opus_packet in opus_parser.packets:
# We only support sending one Opus frame per RTP packet
# TODO: check the spec for the first byte value here
opus_payload = bytes([1]) + opus_packet.payload
elapsed_s = elapsed_ms / 1000
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
rtp_packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
)
rtp_packet.timestamp_seconds = elapsed_s
yield rtp_packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
elapsed_ms += opus_packet.duration
return generate_packets()
# -----------------------------------------------------------------------------
# This map should be left at the end of the file so it can refer to the classes
# above
# -----------------------------------------------------------------------------
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
OpusMediaCodecInformation.VENDOR_ID: {
OpusMediaCodecInformation.CODEC_ID: OpusMediaCodecInformation
}
}

View File

@@ -23,7 +23,6 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import functools
import inspect
@@ -42,7 +41,6 @@ from typing import (
from pyee import EventEmitter
from bumble import utils
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
@@ -147,57 +145,43 @@ ATT_RESPONSES = [
ATT_EXECUTE_WRITE_RESPONSE
]
class ErrorCode(utils.OpenIntEnum):
'''
See
ATT_INVALID_HANDLE_ERROR = 0x01
ATT_READ_NOT_PERMITTED_ERROR = 0x02
ATT_WRITE_NOT_PERMITTED_ERROR = 0x03
ATT_INVALID_PDU_ERROR = 0x04
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = 0x05
ATT_REQUEST_NOT_SUPPORTED_ERROR = 0x06
ATT_INVALID_OFFSET_ERROR = 0x07
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = 0x08
ATT_PREPARE_QUEUE_FULL_ERROR = 0x09
ATT_ATTRIBUTE_NOT_FOUND_ERROR = 0x0A
ATT_ATTRIBUTE_NOT_LONG_ERROR = 0x0B
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = 0x0C
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = 0x0D
ATT_UNLIKELY_ERROR_ERROR = 0x0E
ATT_INSUFFICIENT_ENCRYPTION_ERROR = 0x0F
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = 0x10
ATT_INSUFFICIENT_RESOURCES_ERROR = 0x11
* Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
* Core Specification Supplement: Common Profile And Service Error Codes
'''
INVALID_HANDLE = 0x01
READ_NOT_PERMITTED = 0x02
WRITE_NOT_PERMITTED = 0x03
INVALID_PDU = 0x04
INSUFFICIENT_AUTHENTICATION = 0x05
REQUEST_NOT_SUPPORTED = 0x06
INVALID_OFFSET = 0x07
INSUFFICIENT_AUTHORIZATION = 0x08
PREPARE_QUEUE_FULL = 0x09
ATTRIBUTE_NOT_FOUND = 0x0A
ATTRIBUTE_NOT_LONG = 0x0B
INSUFFICIENT_ENCRYPTION_KEY_SIZE = 0x0C
INVALID_ATTRIBUTE_LENGTH = 0x0D
UNLIKELY_ERROR = 0x0E
INSUFFICIENT_ENCRYPTION = 0x0F
UNSUPPORTED_GROUP_TYPE = 0x10
INSUFFICIENT_RESOURCES = 0x11
DATABASE_OUT_OF_SYNC = 0x12
VALUE_NOT_ALLOWED = 0x13
# 0x80 0x9F: Application Error
# 0xE0 0xFF: Common Profile and Service Error Codes
WRITE_REQUEST_REJECTED = 0xFC
CCCD_IMPROPERLY_CONFIGURED = 0xFD
PROCEDURE_ALREADY_IN_PROGRESS = 0xFE
OUT_OF_RANGE = 0xFF
# Backward Compatible Constants
ATT_INVALID_HANDLE_ERROR = ErrorCode.INVALID_HANDLE
ATT_READ_NOT_PERMITTED_ERROR = ErrorCode.READ_NOT_PERMITTED
ATT_WRITE_NOT_PERMITTED_ERROR = ErrorCode.WRITE_NOT_PERMITTED
ATT_INVALID_PDU_ERROR = ErrorCode.INVALID_PDU
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = ErrorCode.INSUFFICIENT_AUTHENTICATION
ATT_REQUEST_NOT_SUPPORTED_ERROR = ErrorCode.REQUEST_NOT_SUPPORTED
ATT_INVALID_OFFSET_ERROR = ErrorCode.INVALID_OFFSET
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = ErrorCode.INSUFFICIENT_AUTHORIZATION
ATT_PREPARE_QUEUE_FULL_ERROR = ErrorCode.PREPARE_QUEUE_FULL
ATT_ATTRIBUTE_NOT_FOUND_ERROR = ErrorCode.ATTRIBUTE_NOT_FOUND
ATT_ATTRIBUTE_NOT_LONG_ERROR = ErrorCode.ATTRIBUTE_NOT_LONG
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION_KEY_SIZE
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = ErrorCode.INVALID_ATTRIBUTE_LENGTH
ATT_UNLIKELY_ERROR_ERROR = ErrorCode.UNLIKELY_ERROR
ATT_INSUFFICIENT_ENCRYPTION_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = ErrorCode.UNSUPPORTED_GROUP_TYPE
ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
ATT_ERROR_NAMES = {
ATT_INVALID_HANDLE_ERROR: 'ATT_INVALID_HANDLE_ERROR',
ATT_READ_NOT_PERMITTED_ERROR: 'ATT_READ_NOT_PERMITTED_ERROR',
ATT_WRITE_NOT_PERMITTED_ERROR: 'ATT_WRITE_NOT_PERMITTED_ERROR',
ATT_INVALID_PDU_ERROR: 'ATT_INVALID_PDU_ERROR',
ATT_INSUFFICIENT_AUTHENTICATION_ERROR: 'ATT_INSUFFICIENT_AUTHENTICATION_ERROR',
ATT_REQUEST_NOT_SUPPORTED_ERROR: 'ATT_REQUEST_NOT_SUPPORTED_ERROR',
ATT_INVALID_OFFSET_ERROR: 'ATT_INVALID_OFFSET_ERROR',
ATT_INSUFFICIENT_AUTHORIZATION_ERROR: 'ATT_INSUFFICIENT_AUTHORIZATION_ERROR',
ATT_PREPARE_QUEUE_FULL_ERROR: 'ATT_PREPARE_QUEUE_FULL_ERROR',
ATT_ATTRIBUTE_NOT_FOUND_ERROR: 'ATT_ATTRIBUTE_NOT_FOUND_ERROR',
ATT_ATTRIBUTE_NOT_LONG_ERROR: 'ATT_ATTRIBUTE_NOT_LONG_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR',
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR: 'ATT_INVALID_ATTRIBUTE_LENGTH_ERROR',
ATT_UNLIKELY_ERROR_ERROR: 'ATT_UNLIKELY_ERROR_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_ERROR',
ATT_UNSUPPORTED_GROUP_TYPE_ERROR: 'ATT_UNSUPPORTED_GROUP_TYPE_ERROR',
ATT_INSUFFICIENT_RESOURCES_ERROR: 'ATT_INSUFFICIENT_RESOURCES_ERROR'
}
ATT_DEFAULT_MTU = 23
@@ -261,9 +245,9 @@ class ATT_PDU:
def pdu_name(op_code):
return name_or_number(ATT_PDU_NAMES, op_code, 2)
@classmethod
def error_name(cls, error_code: int) -> str:
return ErrorCode(error_code).name
@staticmethod
def error_name(error_code):
return name_or_number(ATT_ERROR_NAMES, error_code, 2)
@staticmethod
def subclass(fields):
@@ -811,7 +795,7 @@ class Attribute(EventEmitter):
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str}\nGot: {permissions_str}"
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc
# Permission flags(legacy-use only)

View File

@@ -119,7 +119,7 @@ class Frame:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5 or subunit_id == 7:
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
@@ -132,6 +132,7 @@ class Frame:
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise core.InvalidPacketError("reserved subunit ID")

View File

@@ -17,10 +17,12 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import struct
import time
import logging
import enum
import warnings
from pyee import EventEmitter
from typing import (
Any,
Awaitable,
@@ -37,8 +39,6 @@ from typing import (
cast,
)
from pyee import EventEmitter
from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
@@ -51,16 +51,13 @@ from .a2dp import (
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE,
A2DP_SBC_CODEC_TYPE,
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES,
AacMediaCodecInformation,
SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation,
)
from .rtp import MediaPacket
from . import sdp, device, l2cap
from .colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -281,6 +278,95 @@ class RealtimeClock:
await asyncio.sleep(duration)
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number & 0xFFFF
self.timestamp = timestamp & 0xFFFFFFFF
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack(
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class MediaPacketPump:
pump_task: Optional[asyncio.Task]
@@ -291,7 +377,6 @@ class MediaPacketPump:
self.packets = packets
self.clock = clock
self.pump_task = None
self.completed = asyncio.Event()
async def start(self, rtp_channel: l2cap.ClassicChannel) -> None:
async def pump_packets():
@@ -321,8 +406,6 @@ class MediaPacketPump:
)
except asyncio.exceptions.CancelledError:
logger.debug('pump canceled')
finally:
self.completed.set()
# Pump packets
self.pump_task = asyncio.create_task(pump_packets())
@@ -334,9 +417,6 @@ class MediaPacketPump:
await self.pump_task
self.pump_task = None
async def wait_for_completion(self) -> None:
await self.completed.wait()
# -----------------------------------------------------------------------------
class MessageAssembler:
@@ -500,10 +580,10 @@ class ServiceCapabilities:
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details: Optional[List[str]] = None) -> str:
def to_string(self, details: List[str] = []) -> str:
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ (details or [])
+ details
)
return f'ServiceCapabilities({attributes})'
@@ -535,25 +615,11 @@ class MediaCodecCapabilities(ServiceCapabilities):
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
vendor_media_codec_information = (
self.media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
self.media_codec_information = media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
else:
self.media_codec_information = vendor_media_codec_information
def __init__(
self,
@@ -1250,20 +1316,10 @@ class Protocol(EventEmitter):
return None
def add_source(
self,
codec_capabilities: MediaCodecCapabilities,
packet_pump: MediaPacketPump,
delay_reporting: bool = False,
self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump
) -> LocalSource:
seid = len(self.local_endpoints) + 1
service_capabilities = (
[ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)]
if delay_reporting
else []
)
source = LocalSource(
self, seid, codec_capabilities, service_capabilities, packet_pump
)
source = LocalSource(self, seid, codec_capabilities, packet_pump)
self.local_endpoints.append(source)
return source
@@ -1316,7 +1372,7 @@ class Protocol(EventEmitter):
return self.remote_endpoints.values()
def find_remote_sink_by_codec(
self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
self, media_type: int, codec_type: int
) -> Optional[DiscoveredStreamEndPoint]:
for endpoint in self.remote_endpoints.values():
if (
@@ -1341,19 +1397,7 @@ class Protocol(EventEmitter):
codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and codec_capabilities.media_codec_type == codec_type
):
if isinstance(
codec_capabilities.media_codec_information,
VendorSpecificMediaCodecInformation,
):
if (
codec_capabilities.media_codec_information.vendor_id
== vendor_id
and codec_capabilities.media_codec_information.codec_id
== codec_id
):
has_codec = True
else:
has_codec = True
has_codec = True
if has_media_transport and has_codec:
return endpoint
@@ -2136,13 +2180,12 @@ class LocalSource(LocalStreamEndPoint):
protocol: Protocol,
seid: int,
codec_capabilities: MediaCodecCapabilities,
other_capabilitiles: Iterable[ServiceCapabilities],
packet_pump: MediaPacketPump,
) -> None:
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities,
] + list(other_capabilitiles)
]
super().__init__(
protocol,
seid,

View File

@@ -1491,14 +1491,10 @@ class Protocol(pyee.EventEmitter):
f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}"
)
# Only addressing the unit, or the PANEL subunit with subunit ID 0 is supported
# in this profile.
if not (
command.subunit_type == avc.Frame.SubunitType.UNIT
and command.subunit_id == 7
) and not (
command.subunit_type == avc.Frame.SubunitType.PANEL
and command.subunit_id == 0
# Only the PANEL subunit type with subunit ID 0 is supported in this profile.
if (
command.subunit_type != avc.Frame.SubunitType.PANEL
or command.subunit_id != 0
):
logger.debug("subunit not supported")
self.send_not_implemented_response(transaction_label, command)
@@ -1532,8 +1528,8 @@ class Protocol(pyee.EventEmitter):
# TODO: delegate
response = avc.PassThroughResponseFrame(
avc.ResponseFrame.ResponseCode.ACCEPTED,
command.subunit_type,
command.subunit_id,
avc.Frame.SubunitType.PANEL,
0,
command.state_flag,
command.operation_id,
command.operation_data,
@@ -1850,15 +1846,6 @@ class Protocol(pyee.EventEmitter):
RejectedResponse(pdu_id, status_code),
)
def send_not_implemented_avrcp_response(
self, transaction_label: int, pdu_id: Protocol.PduId
) -> None:
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED,
NotImplementedResponse(pdu_id, b''),
)
def _on_get_capabilities_command(
self, transaction_label: int, command: GetCapabilitiesCommand
) -> None:
@@ -1904,35 +1891,29 @@ class Protocol(pyee.EventEmitter):
async def register_notification():
# Check if the event is supported.
supported_events = await self.delegate.get_supported_events()
if command.event_id not in supported_events:
logger.debug("event not supported")
self.send_not_implemented_avrcp_response(
transaction_label, self.PduId.REGISTER_NOTIFICATION
)
return
if command.event_id in supported_events:
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
# TODO: testing only, use delegate
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING)
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
# TODO: testing only, use delegate
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING)
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
self._delegate_command(transaction_label, command, register_notification())

View File

@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
from typing_extensions import Self
from bumble import core
@@ -102,40 +101,12 @@ class BitReader:
break
# -----------------------------------------------------------------------------
class BitWriter:
"""Simple but not optimized bit stream writer."""
data: int
bit_count: int
def __init__(self) -> None:
self.data = 0
self.bit_count = 0
def write(self, value: int, bit_count: int) -> None:
self.data = (self.data << bit_count) | value
self.bit_count += bit_count
def write_bytes(self, data: bytes) -> None:
bit_count = 8 * len(data)
self.data = (self.data << bit_count) | int.from_bytes(data, 'big')
self.bit_count += bit_count
def __bytes__(self) -> bytes:
return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes(
(self.bit_count + 7) // 8, 'big'
)
# -----------------------------------------------------------------------------
class AacAudioRtpPacket:
"""AAC payload encapsulated in an RTP packet payload"""
audio_mux_element: AudioMuxElement
@staticmethod
def read_latm_value(reader: BitReader) -> int:
def latm_value(reader: BitReader) -> int:
bytes_for_value = reader.read(2)
value = 0
for _ in range(bytes_for_value + 1):
@@ -143,33 +114,24 @@ class AacAudioRtpPacket:
return value
@staticmethod
def read_audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return audio_object_type
def program_config_element(reader: BitReader):
raise core.InvalidPacketError('program_config_element not supported')
@dataclass
class GASpecificConfig:
audio_object_type: int
# NOTE: other fields not supported
@classmethod
def from_bits(
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self:
def __init__(
self, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> None:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
core_coder_delay = reader.read(14)
self.core_coder_delay = reader.read(14)
extension_flag = reader.read(1)
if not channel_configuration:
raise core.InvalidPacketError('program_config_element not supported')
AacAudioRtpPacket.program_config_element(reader)
if audio_object_type in (6, 20):
layer_nr = reader.read(3)
self.layer_nr = reader.read(3)
if extension_flag:
if audio_object_type == 22:
num_of_sub_frame = reader.read(5)
@@ -182,13 +144,14 @@ class AacAudioRtpPacket:
if extension_flag_3 == 1:
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
return cls(audio_object_type)
@staticmethod
def audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
def to_bits(self, writer: BitWriter) -> None:
assert self.audio_object_type in (1, 2)
writer.write(0, 1) # frame_length_flag = 0
writer.write(0, 1) # depends_on_core_coder = 0
writer.write(0, 1) # extension_flag = 0
return audio_object_type
@dataclass
class AudioSpecificConfig:
@@ -196,7 +159,6 @@ class AacAudioRtpPacket:
sampling_frequency_index: int
sampling_frequency: int
channel_configuration: int
ga_specific_config: AacAudioRtpPacket.GASpecificConfig
sbr_present_flag: int
ps_present_flag: int
extension_audio_object_type: int
@@ -220,73 +182,44 @@ class AacAudioRtpPacket:
7350,
]
@classmethod
def for_simple_aac(
cls,
audio_object_type: int,
sampling_frequency: int,
channel_configuration: int,
) -> Self:
if sampling_frequency not in cls.SAMPLING_FREQUENCIES:
raise ValueError(f'invalid sampling frequency {sampling_frequency}')
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type)
return cls(
audio_object_type=audio_object_type,
sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index(
sampling_frequency
),
sampling_frequency=sampling_frequency,
channel_configuration=channel_configuration,
ga_specific_config=ga_specific_config,
sbr_present_flag=0,
ps_present_flag=0,
extension_audio_object_type=0,
extension_sampling_frequency_index=0,
extension_sampling_frequency=0,
extension_channel_configuration=0,
)
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
def __init__(self, reader: BitReader) -> None:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
sampling_frequency_index = reader.read(4)
if sampling_frequency_index == 0xF:
sampling_frequency = reader.read(24)
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
self.sampling_frequency_index = reader.read(4)
if self.sampling_frequency_index == 0xF:
self.sampling_frequency = reader.read(24)
else:
sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index]
channel_configuration = reader.read(4)
sbr_present_flag = 0
ps_present_flag = 0
extension_sampling_frequency_index = 0
extension_sampling_frequency = 0
extension_channel_configuration = 0
extension_audio_object_type = 0
if audio_object_type in (5, 29):
extension_audio_object_type = 5
sbr_present_flag = 1
if audio_object_type == 29:
ps_present_flag = 1
extension_sampling_frequency_index = reader.read(4)
if extension_sampling_frequency_index == 0xF:
extension_sampling_frequency = reader.read(24)
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
self.sampling_frequency_index
]
self.channel_configuration = reader.read(4)
self.sbr_present_flag = -1
self.ps_present_flag = -1
if self.audio_object_type in (5, 29):
self.extension_audio_object_type = 5
self.sbc_present_flag = 1
if self.audio_object_type == 29:
self.ps_present_flag = 1
self.extension_sampling_frequency_index = reader.read(4)
if self.extension_sampling_frequency_index == 0xF:
self.extension_sampling_frequency = reader.read(24)
else:
extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[
extension_sampling_frequency_index
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
self.extension_sampling_frequency_index
]
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
if audio_object_type == 22:
extension_channel_configuration = reader.read(4)
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
if self.audio_object_type == 22:
self.extension_channel_configuration = reader.read(4)
else:
self.extension_audio_object_type = 0
if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits(
reader, channel_configuration, audio_object_type
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
reader, self.channel_configuration, self.audio_object_type
)
else:
raise core.InvalidPacketError(
f'audioObjectType {audio_object_type} not supported'
f'audioObjectType {self.audio_object_type} not supported'
)
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
@@ -315,44 +248,13 @@ class AacAudioRtpPacket:
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
# self.extension_channel_configuration = reader.read(4)
return cls(
audio_object_type,
sampling_frequency_index,
sampling_frequency,
channel_configuration,
ga_specific_config,
sbr_present_flag,
ps_present_flag,
extension_audio_object_type,
extension_sampling_frequency_index,
extension_sampling_frequency,
extension_channel_configuration,
)
def to_bits(self, writer: BitWriter) -> None:
if self.sampling_frequency_index >= 15:
raise ValueError(
f"unsupported sampling frequency index {self.sampling_frequency_index}"
)
if self.audio_object_type not in (1, 2):
raise ValueError(
f"unsupported audio object type {self.audio_object_type} "
)
writer.write(self.audio_object_type, 5)
writer.write(self.sampling_frequency_index, 4)
writer.write(self.channel_configuration, 4)
self.ga_specific_config.to_bits(writer)
@dataclass
class StreamMuxConfig:
other_data_present: int
other_data_len_bits: int
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
def __init__(self, reader: BitReader) -> None:
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
audio_mux_version = reader.read(1)
if audio_mux_version == 1:
@@ -362,7 +264,7 @@ class AacAudioRtpPacket:
if audio_mux_version_a != 0:
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
@@ -373,13 +275,13 @@ class AacAudioRtpPacket:
if num_layer != 0:
raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0:
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
)
else:
asc_len = AacAudioRtpPacket.read_latm_value(reader)
asc_len = AacAudioRtpPacket.latm_value(reader)
marker = reader.bit_position
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
)
audio_specific_config_len = reader.bit_position - marker
@@ -397,49 +299,36 @@ class AacAudioRtpPacket:
f'frame_length_type {frame_length_type} not supported'
)
other_data_present = reader.read(1)
other_data_len_bits = 0
if other_data_present:
self.other_data_present = reader.read(1)
if self.other_data_present:
if audio_mux_version == 1:
other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader)
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
else:
self.other_data_len_bits = 0
while True:
other_data_len_bits *= 256
self.other_data_len_bits *= 256
other_data_len_esc = reader.read(1)
other_data_len_bits += reader.read(8)
self.other_data_len_bits += reader.read(8)
if other_data_len_esc == 0:
break
crc_check_present = reader.read(1)
if crc_check_present:
crc_checksum = reader.read(8)
return cls(other_data_present, other_data_len_bits, audio_specific_config)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # audioMuxVersion = 0
writer.write(1, 1) # allStreamsSameTimeFraming = 1
writer.write(0, 6) # numSubFrames = 0
writer.write(0, 4) # numProgram = 0
writer.write(0, 3) # numLayer = 0
self.audio_specific_config.to_bits(writer)
writer.write(0, 3) # frameLengthType = 0
writer.write(0, 8) # latmBufferFullness = 0
writer.write(0, 1) # otherDataPresent = 0
writer.write(0, 1) # crcCheckPresent = 0
@dataclass
class AudioMuxElement:
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
payload: bytes
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
# (only supports mux_config_present=1)
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader)
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support:
# allStreamsSameTimeFraming == 1
@@ -455,46 +344,19 @@ class AacAudioRtpPacket:
if tmp != 255:
break
payload = reader.read_bytes(mux_slot_length_bytes)
self.payload = reader.read_bytes(mux_slot_length_bytes)
if stream_mux_config.other_data_present:
reader.skip(stream_mux_config.other_data_len_bits)
if self.stream_mux_config.other_data_present:
reader.skip(self.stream_mux_config.other_data_len_bits)
# ByteAlign
while reader.bit_position % 8:
reader.read(1)
return cls(stream_mux_config, payload)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # useSameStreamMux = 0
self.stream_mux_config.to_bits(writer)
mux_slot_length_bytes = len(self.payload)
while mux_slot_length_bytes > 255:
writer.write(255, 8)
mux_slot_length_bytes -= 255
writer.write(mux_slot_length_bytes, 8)
if mux_slot_length_bytes == 255:
writer.write(0, 8)
writer.write_bytes(self.payload)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
def __init__(self, data: bytes) -> None:
# Parse the bit stream
reader = BitReader(data)
return cls(cls.AudioMuxElement.from_bits(reader))
@classmethod
def for_simple_aac(
cls, sampling_frequency: int, channel_configuration: int, payload: bytes
) -> Self:
audio_specific_config = cls.AudioSpecificConfig.for_simple_aac(
2, sampling_frequency, channel_configuration
)
stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config)
audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload)
return cls(audio_mux_element)
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
def to_adts(self):
# pylint: disable=line-too-long
@@ -521,11 +383,3 @@ class AacAudioRtpPacket:
)
+ self.audio_mux_element.payload
)
def __init__(self, audio_mux_element: AudioMuxElement) -> None:
self.audio_mux_element = audio_mux_element
def __bytes__(self) -> bytes:
writer = BitWriter()
self.audio_mux_element.to_bits(writer)
return bytes(writer)

View File

@@ -148,10 +148,6 @@ class InvalidOperationError(BaseBumbleError, RuntimeError):
"""Invalid Operation Error"""
class NotSupportedError(BaseBumbleError, RuntimeError):
"""Not Supported"""
class OutOfResourcesError(BaseBumbleError, RuntimeError):
"""Out of Resources Error"""

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -151,7 +149,7 @@ QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder:
class G722Decoder(object):
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
@@ -159,7 +157,7 @@ class G722Decoder:
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self) -> None:
def __init__(self):
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
@@ -167,12 +165,12 @@ class G722Decoder:
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
def decode_frame(self, encoded_data) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
def g722_decode(self, result_array, encoded_data) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
@@ -200,16 +198,14 @@ class G722Decoder:
return result_length
def update_decoded_result(
self, xout: int, byte_length: int, byte_array: bytearray
) -> int:
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits: int) -> int:
def lower_sub_band_decoder(self, lower_bits) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
@@ -262,7 +258,7 @@ class G722Decoder:
return rlow
def higher_sub_band_decoder(self, higher_bits: int) -> int:
def higher_sub_band_decoder(self, higher_bits) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
@@ -310,14 +306,14 @@ class G722Decoder:
# -----------------------------------------------------------------------------
class Band:
"""Structure for G722 decode processing."""
class Band(object):
"""Structure for G722 decode proccessing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self) -> None:
def __init__(self):
self._sp = 0
self._sz = 0
self._r = [0] * 3

View File

@@ -51,7 +51,6 @@ from typing_extensions import Self
from pyee import EventEmitter
from bumble import hci
from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from .gatt import Characteristic, Descriptor, Service
@@ -113,7 +112,6 @@ from .hci import (
HCI_LE_Periodic_Advertising_Create_Sync_Command,
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command,
HCI_LE_Periodic_Advertising_Report_Event,
HCI_LE_Periodic_Advertising_Sync_Transfer_Command,
HCI_LE_Periodic_Advertising_Terminate_Sync_Command,
HCI_LE_Enable_Encryption_Command,
HCI_LE_Extended_Advertising_Report_Event,
@@ -170,12 +168,11 @@ from .hci import (
OwnAddressType,
LeFeature,
LeFeatureMask,
LmpFeatureMask,
Phy,
phy_list_to_bits,
)
from .host import Host
from .profiles.gap import GenericAccessService
from .gap import GenericAccessService
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
@@ -190,7 +187,6 @@ from .core import (
InvalidArgumentError,
InvalidOperationError,
InvalidStateError,
NotSupportedError,
OutOfResourcesError,
UnreachableError,
)
@@ -207,13 +203,13 @@ from .keys import (
KeyStore,
PairingKeys,
)
from bumble import pairing
from bumble import gatt_client
from bumble import gatt_server
from bumble import smp
from bumble import sdp
from bumble import l2cap
from bumble import core
from .pairing import PairingConfig
from . import gatt_client
from . import gatt_server
from . import smp
from . import sdp
from . import l2cap
from . import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
@@ -972,25 +968,20 @@ class PeriodicAdvertisingSync(EventEmitter):
response = await self.device.send_command(
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(),
)
if response.return_parameters == HCI_SUCCESS:
if response.status == HCI_SUCCESS:
if self in self.device.periodic_advertising_syncs:
self.device.periodic_advertising_syncs.remove(self)
return
if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST):
self.state = self.State.TERMINATED
if self.sync_handle is not None:
await self.device.send_command(
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
sync_handle=self.sync_handle
)
await self.device.send_command(
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
sync_handle=self.sync_handle
)
)
self.device.periodic_advertising_syncs.remove(self)
async def transfer(self, connection: Connection, service_data: int = 0) -> None:
if self.sync_handle is not None:
await connection.transfer_periodic_sync(self.sync_handle, service_data)
def on_establishment(
self,
status,
@@ -1216,13 +1207,8 @@ class Peer:
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
if proxy := proxy_class.from_client(self.gatt_client):
return cast(_PROXY_CLASS, proxy)
return None
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
@@ -1507,9 +1493,11 @@ class Connection(CompositeEventEmitter):
try:
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
finally:
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
except asyncio.TimeoutError:
pass
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)
@@ -1540,11 +1528,6 @@ class Connection(CompositeEventEmitter):
async def get_phy(self):
return await self.device.get_connection_phy(self)
async def transfer_periodic_sync(
self, sync_handle: int, service_data: int = 0
) -> None:
await self.device.transfer_periodic_sync(self, sync_handle, service_data)
# [Classic only]
async def request_remote_name(self):
return await self.device.request_remote_name(self)
@@ -1571,22 +1554,14 @@ class Connection(CompositeEventEmitter):
raise
def __str__(self):
if self.transport == BT_LE_TRANSPORT:
return (
f'Connection(transport=LE, handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
)
else:
return (
f'Connection(transport=BR/EDR, handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'peer_address={self.peer_address})'
)
return (
f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
)
# -----------------------------------------------------------------------------
@@ -1608,7 +1583,6 @@ class DeviceConfiguration:
classic_ssp_enabled: bool = True
classic_smp_enabled: bool = True
classic_accept_any: bool = True
classic_interlaced_scan_enabled: bool = True
connectable: bool = True
discoverable: bool = True
advertising_data: bytes = bytes(
@@ -1621,8 +1595,6 @@ class DeviceConfiguration:
address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
@@ -1774,9 +1746,9 @@ device_host_event_handlers: List[str] = []
# -----------------------------------------------------------------------------
class Device(CompositeEventEmitter):
# Incomplete list of fields.
random_address: Address # Random private address that may change periodically
public_address: Address # Public address that is globally unique (from controller)
static_address: Address # Random static address that does not change once set
random_address: Address # Random address that may change with RPA
public_address: Address # Public address (obtained from the controller)
static_address: Address # Random address that can be set but does not change
classic_enabled: bool
name: str
class_of_device: int
@@ -1923,7 +1895,6 @@ class Device(CompositeEventEmitter):
self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_smp_enabled = config.classic_smp_enabled
self.classic_interlaced_scan_enabled = config.classic_interlaced_scan_enabled
self.discoverable = config.discoverable
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
@@ -1988,19 +1959,7 @@ class Device(CompositeEventEmitter):
# Setup SMP
self.smp_manager = smp.Manager(
self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
self, pairing_config_factory=lambda connection: PairingConfig()
)
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2224,15 +2183,10 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
),
check_result=True,
)
)
if self.le_enabled:
# Generate a random address if not set.
if self.static_address == Address.ANY_RANDOM:
self.static_address = Address.generate_static_address()
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
@@ -2242,8 +2196,23 @@ class Device(CompositeEventEmitter):
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
else:
self.random_address = self.static_address
# Set the controller address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
)
# Ensure the address bytes can be a static random address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
if self.random_address != Address.ANY_RANDOM:
logger.debug(
@@ -2268,8 +2237,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
),
check_result=True,
)
)
if self.cis_enabled:
@@ -2277,8 +2245,7 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1,
),
check_result=True,
)
)
if self.classic_enabled:
@@ -2301,21 +2268,6 @@ class Device(CompositeEventEmitter):
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
if self.classic_interlaced_scan_enabled:
if self.host.supports_lmp_features(LmpFeatureMask.INTERLACED_PAGE_SCAN):
await self.send_command(
hci.HCI_Write_Page_Scan_Type_Command(page_scan_type=1),
check_result=True,
)
if self.host.supports_lmp_features(
LmpFeatureMask.INTERLACED_INQUIRY_SCAN
):
await self.send_command(
hci.HCI_Write_Inquiry_Scan_Type_Command(scan_type=1),
check_result=True,
)
# Done
self.powered_on = True
@@ -2413,10 +2365,6 @@ class Device(CompositeEventEmitter):
def supports_le_extended_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING)
@property
def supports_le_periodic_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_PERIODIC_ADVERTISING)
async def start_advertising(
self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
@@ -2819,10 +2767,6 @@ class Device(CompositeEventEmitter):
sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT,
filter_duplicates: bool = False,
) -> PeriodicAdvertisingSync:
# Check that the controller supports the feature.
if not self.supports_le_periodic_advertising:
raise NotSupportedError()
# Check that there isn't already an equivalent entry
if any(
sync.advertiser_address == advertiser_address and sync.sid == sid
@@ -3020,47 +2964,18 @@ class Device(CompositeEventEmitter):
] = None,
own_address_type: int = OwnAddressType.RANDOM,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
always_resolve: bool = False,
) -> Connection:
'''
Request a connection to a peer.
When the transport is BLE, this method cannot be called if there is already a
When transport is BLE, this method cannot be called if there is already a
pending connection.
Args:
peer_address:
Address or name of the device to connect to.
If a string is passed:
If the string is an address followed by a `@` suffix, the `always_resolve`
argument is implicitly set to True, so the connection is made to the
address after resolution.
If the string is any other address, the connection is made to that
address (with or without address resolution, depending on the
`always_resolve` argument).
For any other string, a scan for devices using that string as their name
is initiated, and a connection to the first matching device's address
is made. In that case, `always_resolve` is ignored.
connection_parameters_preferences: (BLE only, ignored for BR/EDR)
* None: use the 1M PHY with default parameters
* map: each entry has a PHY as key and a ConnectionParametersPreferences
object as value
connection_parameters_preferences:
(BLE only, ignored for BR/EDR)
* None: use the 1M PHY with default parameters
* map: each entry has a PHY as key and a ConnectionParametersPreferences
object as value
own_address_type:
(BLE only, ignored for BR/EDR)
OwnAddressType.RANDOM to use this device's random address, or
OwnAddressType.PUBLIC to use this device's public address.
timeout:
Maximum time to wait for a connection to be established, in seconds.
Pass None for an unlimited time.
always_resolve:
(BLE only, ignored for BR/EDR)
If True, always initiate a scan, resolving addresses, and connect to the
address that resolves to `peer_address`.
own_address_type: (BLE only)
'''
# Check parameters
@@ -3079,19 +2994,11 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str):
try:
if transport == BT_LE_TRANSPORT and peer_address.endswith('@'):
peer_address = Address.from_string_for_transport(
peer_address[:-1], transport
)
always_resolve = True
logger.debug('forcing address resolution')
else:
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except (InvalidArgumentError, ValueError):
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(
peer_address, transport
@@ -3106,12 +3013,6 @@ class Device(CompositeEventEmitter):
assert isinstance(peer_address, Address)
if transport == BT_LE_TRANSPORT and always_resolve:
logger.debug('resolving address')
peer_address = await self.find_peer_by_identity_address(
peer_address
) # TODO: timeout
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
@@ -3613,26 +3514,15 @@ class Device(CompositeEventEmitter):
check_result=True,
)
async def transfer_periodic_sync(
self, connection: Connection, sync_handle: int, service_data: int = 0
) -> None:
return await self.send_command(
HCI_LE_Periodic_Advertising_Sync_Transfer_Command(
connection_handle=connection.handle,
service_data=service_data,
sync_handle=sync_handle,
),
check_result=True,
)
async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT):
"""
Scan for a peer with a given name and return its address.
Scan for a peer with a give name and return its address and transport
"""
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()
# Scan/inquire with event handlers to handle scan/inquiry results
def on_peer_found(address, ad_data):
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
if local_name is None:
@@ -3641,13 +3531,13 @@ class Device(CompositeEventEmitter):
if local_name.decode('utf-8') == name:
peer_address.set_result(address)
listener = None
handler = None
was_scanning = self.scanning
was_discovering = self.discovering
try:
if transport == BT_LE_TRANSPORT:
event_name = 'advertisement'
listener = self.on(
handler = self.on(
event_name,
lambda advertisement: on_peer_found(
advertisement.address, advertisement.data
@@ -3659,7 +3549,7 @@ class Device(CompositeEventEmitter):
elif transport == BT_BR_EDR_TRANSPORT:
event_name = 'inquiry_result'
listener = self.on(
handler = self.on(
event_name,
lambda address, class_of_device, eir_data, rssi: on_peer_found(
address, eir_data
@@ -3673,67 +3563,21 @@ class Device(CompositeEventEmitter):
return await self.abort_on('flush', peer_address)
finally:
if listener is not None:
self.remove_listener(event_name, listener)
if handler is not None:
self.remove_listener(event_name, handler)
if transport == BT_LE_TRANSPORT and not was_scanning:
await self.stop_scanning()
elif transport == BT_BR_EDR_TRANSPORT and not was_discovering:
await self.stop_discovery()
async def find_peer_by_identity_address(self, identity_address: Address) -> Address:
"""
Scan for a peer with a resolvable address that can be resolved to a given
identity address.
"""
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()
def on_peer_found(address, _):
if address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching public address found for {address}')
peer_address.set_result(address)
return
if address.is_resolvable:
resolved_address = self.address_resolver.resolve(address)
if resolved_address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching identity found for {address}')
peer_address.set_result(address)
return
was_scanning = self.scanning
event_name = 'advertisement'
listener = None
try:
listener = self.on(
event_name,
lambda advertisement: on_peer_found(
advertisement.address, advertisement.data
),
)
if not self.scanning:
await self.start_scanning(filter_duplicates=True)
return await self.abort_on('flush', peer_address)
finally:
if listener is not None:
self.remove_listener(event_name, listener)
if not was_scanning:
await self.stop_scanning()
@property
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter
def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
self, pairing_config_factory: Callable[[Connection], PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory
@@ -3853,7 +3697,6 @@ class Device(CompositeEventEmitter):
if self.keystore is None:
raise InvalidOperationError('no key store')
logger.debug(f'Looking up key for {connection.peer_address}')
keys = await self.keystore.get(str(connection.peer_address))
if keys is None:
raise InvalidOperationError('keys not found in key store')
@@ -4289,12 +4132,6 @@ class Device(CompositeEventEmitter):
else self.public_address
)
if advertising_set.advertising_parameters.own_address_type in (
OwnAddressType.RANDOM,
OwnAddressType.PUBLIC,
):
connection.self_resolvable_address = None
# Setup auto-restart of the advertising set if needed.
if advertising_set.auto_restart:
connection.once(
@@ -4338,15 +4175,6 @@ class Device(CompositeEventEmitter):
role: int,
connection_parameters: ConnectionParameters,
) -> None:
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if (
peer_resolvable_address == Address.ANY_RANDOM
or not peer_address.is_resolved
):
peer_resolvable_address = None
logger.debug(
f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
@@ -4378,7 +4206,6 @@ class Device(CompositeEventEmitter):
peer_address = resolved_address
self_address = None
own_address_type: Optional[int] = None
if role == HCI_CENTRAL_ROLE:
own_address_type = self.connect_own_address_type
assert own_address_type is not None
@@ -4407,10 +4234,11 @@ class Device(CompositeEventEmitter):
else self.random_address
)
# Some controllers may return local resolvable address even not using address
# generation offloading. Ignore the value to prevent SMP failure.
if own_address_type in (OwnAddressType.RANDOM, OwnAddressType.PUBLIC):
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if peer_resolvable_address == Address.ANY_RANDOM:
peer_resolvable_address = None
# Create a connection.
connection = Connection(

View File

@@ -301,8 +301,6 @@ class Driver(common.Driver):
fw_name: str = ""
config_name: str = ""
POST_RESET_DELAY: float = 0.2
DRIVER_INFOS = [
# 8723A
DriverInfo(
@@ -497,24 +495,12 @@ class Driver(common.Driver):
@classmethod
async def driver_info_for_host(cls, host):
try:
await host.send_command(
HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True
command = HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
logger.error("failed to probe local version information")
return None
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
local_version = response.return_parameters
logger.debug(

View File

@@ -39,7 +39,7 @@ from typing import (
)
from bumble.colors import color
from bumble.core import BaseBumbleError, UUID
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
@@ -238,22 +238,22 @@ GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Call Friendly Name')
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
@@ -275,11 +275,6 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# Hearing Access Service
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC = UUID.from_16_bits(0x2BDC, 'Active Preset Index')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -325,11 +320,6 @@ def show_services(services: Iterable[Service]) -> None:
print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
class InvalidServiceError(BaseBumbleError):
"""The service is not compliant with the spec/profile"""
# -----------------------------------------------------------------------------
class Service(Attribute):
'''
@@ -345,7 +335,7 @@ class Service(Attribute):
uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: Iterable[Service] = (),
included_services: List[Service] = [],
) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
@@ -361,7 +351,7 @@ class Service(Attribute):
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = list(included_services)
self.included_services = included_services[:]
self.characteristics = characteristics[:]
self.primary = primary
@@ -395,7 +385,7 @@ class TemplateService(Service):
self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: Iterable[Service] = (),
included_services: List[Service] = [],
) -> None:
super().__init__(self.UUID, characteristics, primary, included_services)

View File

@@ -68,7 +68,7 @@ from .att import (
ATT_Error,
)
from . import core
from .core import UUID, InvalidStateError
from .core import UUID, InvalidStateError, ProtocolError
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -253,7 +253,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: Type[TemplateService]
@classmethod
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -283,8 +283,6 @@ class Client:
self.services = []
self.cached_values = {}
connection.on('disconnection', self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
@@ -345,7 +343,12 @@ class Client:
self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
@@ -402,7 +405,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -931,7 +934,12 @@ class Client:
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
@@ -953,7 +961,12 @@ class Client:
ATT_INVALID_OFFSET_ERROR,
):
break
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
part = response.part_attribute_value
attribute_value += part
@@ -964,6 +977,7 @@ class Client:
offset += len(part)
self.cache_value(attribute_handle, attribute_value)
# Return the value as bytes
return attribute_value
@@ -1046,7 +1060,12 @@ class Client:
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
else:
await self.send_command(
ATT_Write_Command(
@@ -1054,10 +1073,6 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'

View File

@@ -915,7 +915,7 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
# Check that the attribute exists
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
@@ -942,19 +942,11 @@ class Server(EventEmitter):
)
return
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
# Done
response = ATT_Write_Response()
self.send_response(connection, response)
# Accept the value
await attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):

View File

@@ -267,19 +267,6 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
HCI_LE_CS_CONFIG_COMPLETE_EVENT = 0x2F
HCI_LE_CS_PROCEDURE_ENABLE_EVENT = 0x30
HCI_LE_CS_SUBEVENT_RESULT_EVENT = 0x31
HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT = 0x32
HCI_LE_CS_TEST_END_COMPLETE_EVENT = 0x33
HCI_LE_MONITORED_ADVERTISERS_REPORT_EVENT = 0x34
HCI_LE_FRAME_SPACE_UPDATE_EVENT = 0x35
# HCI Command
@@ -586,36 +573,11 @@ HCI_LE_SET_DATA_RELATED_ADDRESS_CHANGES_COMMAND = hci_c
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D)
HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E)
HCI_LE_SET_EXTENDED_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x007F)
HCI_LE_SET_DECISION_DATA_COMMAND = hci_command_op_code(0x08, 0x0080)
HCI_LE_SET_DECISION_INSTRUCTIONS_COMMAND = hci_command_op_code(0x08, 0x0081)
HCI_LE_SET_PERIODIC_ADVERTISING_SUBEVENT_DATA_COMMAND = hci_command_op_code(0x08, 0x0082)
HCI_LE_SET_PERIODIC_ADVERTISING_RESPONSE_DATA_COMMAND = hci_command_op_code(0x08, 0x0083)
HCI_LE_SET_PERIODIC_SYNC_SUBEVENT_COMMAND = hci_command_op_code(0x08, 0x0084)
HCI_LE_EXTENDED_CREATE_CONNECTION_V2_COMMAND = hci_command_op_code(0x08, 0x0085)
HCI_LE_SET_PERIODIC_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x0086)
HCI_LE_READ_ALL_LOCAL_SUPPORTED_FEATURES_COMMAND = hci_command_op_code(0x08, 0x0087)
HCI_LE_READ_ALL_REMOTE_FEATURES_COMMAND = hci_command_op_code(0x08, 0x0088)
HCI_LE_CS_READ_LOCAL_SUPPORTED_CAPABILITIES_COMMAND = hci_command_op_code(0x08, 0x0089)
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMMAND = hci_command_op_code(0x08, 0x008A)
HCI_LE_CS_WRITE_CACHED_REMOTE_SUPPORTED_CAPABILITIES = hci_command_op_code(0x08, 0x008B)
HCI_LE_CS_SECURITY_ENABLE_COMMAND = hci_command_op_code(0x08, 0x008C)
HCI_LE_CS_SET_DEFAULT_SETTINGS_COMMAND = hci_command_op_code(0x08, 0x008D)
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMMAND = hci_command_op_code(0x08, 0x008E)
HCI_LE_CS_WRITE_CACHED_REMOTE_FAE_TABLE_COMMAND = hci_command_op_code(0x08, 0x008F)
HCI_LE_CS_CREATE_CONFIG_COMMAND = hci_command_op_code(0x08, 0x0090)
HCI_LE_CS_REMOVE_CONFIG_COMMAND = hci_command_op_code(0x08, 0x0091)
HCI_LE_CS_SET_CHANNEL_CLASSIFICATION_COMMAND = hci_command_op_code(0x08, 0x0092)
HCI_LE_CS_SET_PROCEDURE_PARAMETERS_COMMAND = hci_command_op_code(0x08, 0x0093)
HCI_LE_CS_PROCEDURE_ENABLE_COMMAND = hci_command_op_code(0x08, 0x0094)
HCI_LE_CS_TEST_COMMAND = hci_command_op_code(0x08, 0x0095)
HCI_LE_CS_TEST_END_COMMAND = hci_command_op_code(0x08, 0x0096)
HCI_LE_SET_HOST_FEATURE_V2_COMMAND = hci_command_op_code(0x08, 0x0097)
HCI_LE_ADD_DEVICE_TO_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x0098)
HCI_LE_REMOVE_DEVICE_FROM_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x0099)
HCI_LE_CLEAR_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x009A)
HCI_LE_READ_MONITORED_ADVERTISERS_LIST_SIZE_COMMAND = hci_command_op_code(0x08, 0x009B)
HCI_LE_ENABLE_MONITORING_ADVERTISERS_COMMAND = hci_command_op_code(0x08, 0x009C)
HCI_LE_FRAME_SPACE_UPDATE_COMMAND = hci_command_op_code(0x08, 0x009D)
# HCI Error Codes
@@ -1188,16 +1150,8 @@ class LeFeature(OpenIntEnum):
CHANNEL_CLASSIFICATION = 39
ADVERTISING_CODING_SELECTION = 40
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 41
DECISION_BASED_ADVERTISING_FILTERING = 42
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 43
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 44
UNSEGMENTED_FRAMED_MODE = 45
CHANNEL_SOUNDING = 46
CHANNEL_SOUNDING_HOST_SUPPORT = 47
CHANNEL_SOUNDING_TONE_QUALITY_INDICATION = 48
LL_EXTENDED_FEATURE_SET = 63
MONITORING_ADVERTISERS = 64
FRAME_SPACE_UPDATE = 65
class LeFeatureMask(enum.IntFlag):
LE_ENCRYPTION = 1 << LeFeature.LE_ENCRYPTION
@@ -1242,16 +1196,8 @@ class LeFeatureMask(enum.IntFlag):
CHANNEL_CLASSIFICATION = 1 << LeFeature.CHANNEL_CLASSIFICATION
ADVERTISING_CODING_SELECTION = 1 << LeFeature.ADVERTISING_CODING_SELECTION
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 1 << LeFeature.ADVERTISING_CODING_SELECTION_HOST_SUPPORT
DECISION_BASED_ADVERTISING_FILTERING = 1 << LeFeature.DECISION_BASED_ADVERTISING_FILTERING
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER
UNSEGMENTED_FRAMED_MODE = 1 << LeFeature.UNSEGMENTED_FRAMED_MODE
CHANNEL_SOUNDING = 1 << LeFeature.CHANNEL_SOUNDING
CHANNEL_SOUNDING_HOST_SUPPORT = 1 << LeFeature.CHANNEL_SOUNDING_HOST_SUPPORT
CHANNEL_SOUNDING_TONE_QUALITY_INDICATION = 1 << LeFeature.CHANNEL_SOUNDING_TONE_QUALITY_INDICATION
LL_EXTENDED_FEATURE_SET = 1 << LeFeature.LL_EXTENDED_FEATURE_SET
MONITORING_ADVERTISERS = 1 << LeFeature.MONITORING_ADVERTISERS
FRAME_SPACE_UPDATE = 1 << LeFeature.FRAME_SPACE_UPDATE
class LmpFeature(enum.IntEnum):
# Page 0 (Legacy LMP features)
@@ -1619,16 +1565,12 @@ class HCI_Object:
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
# Set fields first, because item_count might be 0.
for sub_field_name, _ in field:
result[sub_field_name] = []
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
result[sub_field_name].append(value)
result.setdefault(sub_field_name, []).append(value)
offset += size
continue
@@ -3040,27 +2982,6 @@ class HCI_Write_Inquiry_Scan_Activity_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('authentication_enable', 1),
]
)
class HCI_Read_Authentication_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.23 Read Authentication Enable Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('authentication_enable', 1)])
class HCI_Write_Authentication_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.24 Write Authentication Enable Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
@@ -3101,12 +3022,7 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('synchronous_flow_control_enable', 1),
]
)
@HCI_Command.command()
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
@@ -3275,13 +3191,7 @@ class HCI_Set_Event_Mask_Page_2_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('le_supported_host', 1),
('unused', 1),
]
)
@HCI_Command.command()
class HCI_Read_LE_Host_Support_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.78 Read LE Host Support Command
@@ -3414,39 +3324,13 @@ class HCI_Read_BD_ADDR_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
("status", STATUS_SPEC),
[("standard_codec_ids", 1)],
[("vendor_specific_codec_ids", 4)],
]
)
@HCI_Command.command()
class HCI_Read_Local_Supported_Codecs_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.8 Read Local Supported Codecs Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
("status", STATUS_SPEC),
[("standard_codec_ids", 1), ("standard_codec_transports", 1)],
[("vendor_specific_codec_ids", 4), ("vendor_specific_codec_transports", 1)],
]
)
class HCI_Read_Local_Supported_Codecs_V2_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.8 Read Local Supported Codecs Command
'''
class Transport(enum.IntFlag):
BR_EDR_ACL = 1 << 0
BR_EDR_SCO = 1 << 1
LE_CIS = 1 << 2
LE_BIS = 1 << 3
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle', 2)],
@@ -3604,12 +3488,7 @@ class HCI_LE_Set_Advertising_Parameters_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('tx_power_level', 1),
]
)
@HCI_Command.command()
class HCI_LE_Read_Advertising_Physical_Channel_Tx_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.6 LE Read Advertising Physical Channel Tx Power Command
@@ -3733,12 +3612,7 @@ class HCI_LE_Create_Connection_Cancel_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('filter_accept_list_size', 1),
]
)
@HCI_Command.command()
class HCI_LE_Read_Filter_Accept_List_Size_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.14 LE Read Filter Accept List Size Command
@@ -3849,12 +3723,7 @@ class HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('le_states', 8),
]
)
@HCI_Command.command()
class HCI_LE_Read_Supported_States_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.27 LE Read Supported States Command
@@ -4660,6 +4529,18 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# -----------------------------------------------------------------------------
@HCI_Command.command(
[
@@ -4695,32 +4576,6 @@ class HCI_LE_Set_Privacy_Mode_Command(HCI_Command):
return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode)
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('connection_handle', 2),
],
)
class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
@@ -4829,102 +4684,6 @@ class HCI_LE_Reject_CIS_Request_Command(HCI_Command):
reason: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('advertising_handle', 1),
('num_bis', 1),
('sdu_interval', 3),
('max_sdu', 2),
('max_transport_latency', 2),
('rtn', 1),
('phy', 1),
('packing', 1),
('framing', 1),
('encryption', 1),
('broadcast_code', 16),
],
)
class HCI_LE_Create_BIG_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.103 LE Create BIG command
'''
big_handle: int
advertising_handle: int
num_bis: int
sdu_interval: int
max_sdu: int
max_transport_latency: int
rtn: int
phy: int
packing: int
framing: int
encryption: int
broadcast_code: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('reason', {'size': 1, 'mapper': HCI_Constant.error_name}),
],
)
class HCI_LE_Terminate_BIG_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.105 LE Terminate BIG command
'''
big_handle: int
reason: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('sync_handle', 2),
('encryption', 1),
('broadcast_code', 16),
('mse', 1),
('big_sync_timeout', 2),
[('bis', 1)],
],
)
class HCI_LE_BIG_Create_Sync_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.106 LE BIG Create Sync command
'''
big_handle: int
sync_handle: int
encryption: int
broadcast_code: int
mse: int
big_sync_timeout: int
bis: List[int]
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
],
return_parameters_fields=[
('status', STATUS_SPEC),
('big_handle', 2),
],
)
class HCI_LE_BIG_Terminate_Sync_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.107. LE BIG Terminate Sync command
'''
big_handle: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
@@ -5760,27 +5519,6 @@ class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
'''
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('service_data', 2),
('sync_handle', 2),
('advertising_sid', 1),
('advertiser_address_type', Address.ADDRESS_TYPE_SPEC),
('advertiser_address', Address.parse_address_preceded_by_type),
('advertiser_phy', 1),
('periodic_advertising_interval', 2),
('advertiser_clock_accuracy', 1),
]
)
class HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event(HCI_LE_Meta_Event):
'''
See Bluetooth spec @ 7.7.65.24 LE Periodic Advertising Sync Transfer Received Event
'''
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
@@ -6065,32 +5803,6 @@ class HCI_Read_Remote_Version_Information_Complete_Event(HCI_Event):
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('unused', 1),
(
'service_type',
{
'size': 1,
'mapper': lambda x: HCI_QOS_Setup_Complete_Event.ServiceType(x).name,
},
),
]
)
class HCI_QOS_Setup_Complete_Event(HCI_Event):
'''
See Bluetooth spec @ 7.7.13 QoS Setup Complete Event
'''
class ServiceType(OpenIntEnum):
NO_TRAFFIC_AVAILABLE = 0x00
BEST_EFFORT_AVAILABLE = 0x01
GUARANTEED_AVAILABLE = 0x02
# -----------------------------------------------------------------------------
@HCI_Event.event(
[
@@ -6499,23 +6211,6 @@ class HCI_Synchronous_Connection_Changed_Event(HCI_Event):
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('max_tx_latency', 2),
('max_rx_latency', 2),
('min_remote_timeout', 2),
('min_local_timeout', 2),
]
)
class HCI_Sniff_Subrating_Event(HCI_Event):
'''
See Bluetooth spec @ 7.7.37 Sniff Subrating Event
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[

View File

@@ -795,32 +795,29 @@ class HfProtocol(pyee.EventEmitter):
# Append to the read buffer.
self.read_buffer.extend(data)
while self.read_buffer:
# Locate header and trailer.
header = self.read_buffer.find(b'\r\n')
trailer = self.read_buffer.find(b'\r\n', header + 2)
if header == -1 or trailer == -1:
return
# Locate header and trailer.
header = self.read_buffer.find(b'\r\n')
trailer = self.read_buffer.find(b'\r\n', header + 2)
if header == -1 or trailer == -1:
return
# Isolate the AT response code and parameters.
raw_response = self.read_buffer[header + 2 : trailer]
response = AtResponse.parse_from(raw_response)
logger.debug(f"<<< {raw_response.decode()}")
# Isolate the AT response code and parameters.
raw_response = self.read_buffer[header + 2 : trailer]
response = AtResponse.parse_from(raw_response)
logger.debug(f"<<< {raw_response.decode()}")
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 2 :]
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(
f"dropping unexpected response with code '{response.code}'"
)
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(f"dropping unexpected response with code '{response.code}'")
async def execute_command(
self,
@@ -1247,32 +1244,31 @@ class AgProtocol(pyee.EventEmitter):
# Append to the read buffer.
self.read_buffer.extend(data)
while self.read_buffer:
# Locate the trailer.
trailer = self.read_buffer.find(b'\r')
if trailer == -1:
return
# Locate the trailer.
trailer = self.read_buffer.find(b'\r')
if trailer == -1:
return
# Isolate the AT response code and parameters.
raw_command = self.read_buffer[:trailer]
command = AtCommand.parse_from(raw_command)
logger.debug(f"<<< {raw_command.decode()}")
# Isolate the AT response code and parameters.
raw_command = self.read_buffer[:trailer]
command = AtCommand.parse_from(raw_command)
logger.debug(f"<<< {raw_command.decode()}")
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 1 :]
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 1 :]
if command.sub_code == AtCommand.SubCode.TEST:
handler_name = f'_on_{command.code.lower()}_test'
elif command.sub_code == AtCommand.SubCode.READ:
handler_name = f'_on_{command.code.lower()}_read'
else:
handler_name = f'_on_{command.code.lower()}'
if command.sub_code == AtCommand.SubCode.TEST:
handler_name = f'_on_{command.code.lower()}_test'
elif command.sub_code == AtCommand.SubCode.READ:
handler_name = f'_on_{command.code.lower()}_read'
else:
handler_name = f'_on_{command.code.lower()}'
if handler := getattr(self, handler_name, None):
handler(*command.parameters)
else:
logger.warning('Handler %s not found', handler_name)
self.send_response('ERROR')
if handler := getattr(self, handler_name, None):
handler(*command.parameters)
else:
logger.warning('Handler %s not found', handler_name)
self.send_response('ERROR')
def send_response(self, response: str) -> None:
"""Sends an AT response."""

View File

@@ -23,12 +23,13 @@ import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from bumble.hci import Address
from .hci import Address
# -----------------------------------------------------------------------------
@@ -219,27 +220,31 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
@@ -329,18 +334,17 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
@dataclass
class GetSetStatus:
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -406,6 +410,7 @@ class Device(HID):
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
@@ -423,9 +428,7 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
@@ -439,6 +442,7 @@ class Device(HID):
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -449,7 +453,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
@@ -460,12 +464,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
@@ -475,14 +480,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")

View File

@@ -171,7 +171,7 @@ class Host(AbortableEventEmitter):
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None
self.pending_response = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
@@ -514,9 +514,7 @@ class Host(AbortableEventEmitter):
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: Optional[int] = None
):
async def send_command(self, command, check_result=False):
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -528,13 +526,12 @@ class Host(AbortableEventEmitter):
try:
self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
response = self.pending_response.result()
response = await self.pending_response
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
status = response.status
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
@@ -628,21 +625,14 @@ class Host(AbortableEventEmitter):
# Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet: bytes) -> None:
try:
hci_packet = hci.HCI_Packet.from_bytes(packet)
except Exception as error:
logger.warning(f'!!! error parsing packet from bytes: {error}')
return
hci_packet = hci.HCI_Packet.from_bytes(packet)
if self.ready or (
isinstance(hci_packet, hci.HCI_Command_Complete_Event)
and hci_packet.command_opcode == hci.HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
logger.debug(
f'reset not done, ignoring packet from controller: {hci_packet}'
)
logger.debug('reset not done, ignoring packet from controller')
def on_transport_lost(self):
# Called by the source when the transport has been lost.
@@ -1106,18 +1096,6 @@ class Host(AbortableEventEmitter):
event.status,
)
def on_hci_qos_setup_complete_event(self, event):
if event.status == hci.HCI_SUCCESS:
self.emit(
'connection_qos_setup', event.connection_handle, event.service_type
)
else:
self.emit(
'connection_qos_setup_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
pass

View File

@@ -25,10 +25,8 @@ import grpc.aio
from .config import Config
from .device import PandoraDevice
from .host import HostService
from .l2cap import L2CAPService
from .security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
@@ -79,7 +77,6 @@ async def serve(
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
# call hooks if any.
for hook in _SERVICERS_HOOKS:

View File

@@ -1,310 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import grpc
import json
import logging
from asyncio import Queue as AsyncQueue, Future
from . import utils
from .config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
ClassicChannelServer,
ClassicChannelSpec,
LeCreditBasedChannel,
LeCreditBasedChannelServer,
LeCreditBasedChannelSpec,
)
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
Channel as PandoraChannel,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
DisconnectRequest,
DisconnectResponse,
ReceiveRequest,
ReceiveResponse,
SendRequest,
SendResponse,
WaitConnectionRequest,
WaitConnectionResponse,
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
class ChannelContext:
close_future: Future
sdu_queue: AsyncQueue
class L2CAPService(L2CAPServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'L2CAP', 'device': device}
)
self.device = device
self.config = config
self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future()
sdu_queue: AsyncQueue = AsyncQueue()
def on_channel_sdu(sdu):
sdu_queue.put_nowait(sdu)
def on_close():
close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close)
return ChannelContext(close_future, sdu_queue)
@utils.rpc
async def WaitConnection(
self, request: WaitConnectionRequest, context: grpc.ServicerContext
) -> WaitConnectionResponse:
self.log.debug('WaitConnection')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
if channel_type.spsm in self.device.l2cap_channel_manager.le_coc_servers:
l2cap_server = self.device.l2cap_channel_manager.le_coc_servers[
channel_type.spsm
]
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
if channel_type.psm in self.device.l2cap_channel_manager.servers:
l2cap_server = self.device.l2cap_channel_manager.servers[
channel_type.psm
]
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
channel_future: Future[PandoraChannel] = (
asyncio.get_running_loop().create_future()
)
def on_l2cap_channel(l2cap_channel: L2capChannel):
try:
channel_context = self.register_event(l2cap_channel)
pandora_channel: PandoraChannel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
channel_future.set_result(pandora_channel)
except Exception as e:
self.log.error(f'Failed to set channel future: {e}')
if l2cap_server is None:
l2cap_server = self.device.create_l2cap_server(
spec=spec, handler=on_l2cap_channel
)
else:
l2cap_server.on('connection', on_l2cap_channel)
try:
self.log.debug('Waiting for a channel connection.')
pandora_channel: PandoraChannel = await channel_future
return WaitConnectionResponse(channel=pandora_channel)
except Exception as e:
self.log.warning(f'Exception: {e}')
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def WaitDisconnection(
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> WaitDisconnectionResponse:
try:
self.log.debug('WaitDisconnection')
await self.lookup_context(request.channel).close_future
self.log.debug("return WaitDisconnectionResponse")
return WaitDisconnectionResponse(success=empty_pb2.Empty())
except KeyError as e:
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
except Exception as e:
self.log.exception(f'WaitDisonnection failed: {e}')
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Receive(
self, request: ReceiveRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ReceiveResponse, None]:
self.log.debug('Receive')
oneof = request.WhichOneof('source')
self.log.debug(f'Source: {oneof}.')
pandora_channel = getattr(request, oneof)
sdu_queue = self.lookup_context(pandora_channel).sdu_queue
while sdu := await sdu_queue.get():
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
response = ReceiveResponse(data=sdu)
yield response
@utils.rpc
async def Connect(
self, request: ConnectRequest, context: grpc.ServicerContext
) -> ConnectResponse:
self.log.debug('Connect')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
try:
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
l2cap_channel = await connection.create_l2cap_channel(spec=spec)
channel_context = self.register_event(l2cap_channel)
pandora_channel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
return ConnectResponse(channel=pandora_channel)
except OutOfResourcesError as e:
self.log.error(e)
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
except InvalidArgumentError as e:
self.log.error(e)
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> DisconnectResponse:
try:
self.log.debug('Disconnect')
l2cap_channel = self.lookup_channel(request.channel)
if not l2cap_channel:
self.log.warning('Disconnect: Unable to find the channel')
return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
await l2cap_channel.disconnect()
return DisconnectResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return DisconnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Send(
self, request: SendRequest, context: grpc.ServicerContext
) -> SendResponse:
self.log.debug('Send')
try:
oneof = request.WhichOneof('sink')
self.log.debug(f'Sink: {oneof}.')
pandora_channel = getattr(request, oneof)
l2cap_channel = self.lookup_channel(pandora_channel)
if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.send_pdu(request.data)
else:
l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
def craft_pandora_channel(
self,
connection_handle: int,
l2cap_channel: L2capChannel,
) -> PandoraChannel:
parameters = {
"connection_handle": connection_handle,
"source_cid": l2cap_channel.source_cid,
}
cookie = any_pb2.Any()
cookie.value = json.dumps(parameters).encode()
return PandoraChannel(cookie=cookie)
def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
(connection_handle, source_cid) = json.loads(
pandora_channel.cookie.value
).values()
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
return self.channels[pandora_channel.cookie.value]

View File

@@ -1,520 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Audio Input Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import gatt
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
DelegatedCharacteristicAdapter,
TemplateService,
CharacteristicValue,
PackedCharacteristicAdapter,
GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
CHANGE_COUNTER_MAX_VALUE = 0xFF
GAIN_SETTINGS_MIN_VALUE = 0
GAIN_SETTINGS_MAX_VALUE = 255
class ErrorCode(OpenIntEnum):
'''
Cf. 1.6 Application error codes
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
MUTE_DISABLED = 0x82
VALUE_OUT_OF_RANGE = 0x83
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
class Mute(OpenIntEnum):
'''
Cf. 2.2.1.2 Mute Field
'''
NOT_MUTED = 0x00
MUTED = 0x01
DISABLED = 0x02
class GainMode(OpenIntEnum):
'''
Cf. 2.2.1.3 Gain Mode
'''
MANUAL_ONLY = 0x00
AUTOMATIC_ONLY = 0x01
MANUAL = 0x02
AUTOMATIC = 0x03
class AudioInputStatus(OpenIntEnum):
'''
Cf. 3.4 Audio Input Status
'''
INATIVE = 0x00
ACTIVE = 0x01
class AudioInputControlPointOpCode(OpenIntEnum):
'''
Cf. 3.5.1 Audio Input Control Point procedure requirements
'''
SET_GAIN_SETTING = 0x00
UNMUTE = 0x02
MUTE = 0x03
SET_MANUAL_GAIN_MODE = 0x04
SET_AUTOMATIC_GAIN_MODE = 0x05
# -----------------------------------------------------------------------------
@dataclass
class AudioInputState:
'''
Cf. 2.2.1 Audio Input State
'''
gain_settings: int = 0
mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return bytes(
[self.gain_settings, self.mute, self.gain_mode, self.change_counter]
)
@classmethod
def from_bytes(cls, data: bytes):
gain_settings, mute, gain_mode, change_counter = struct.unpack("BBBB", data)
return cls(gain_settings, mute, gain_mode, change_counter)
def update_gain_settings_unit(self, gain_settings_unit: int) -> None:
self.gain_settings_unit = gain_settings_unit
def increment_gain_settings(self, gain_settings_unit: int) -> None:
self.gain_settings += gain_settings_unit
self.increment_change_counter()
def decrement_gain_settings(self) -> None:
self.gain_settings -= self.gain_settings_unit
self.increment_change_counter()
def increment_change_counter(self):
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=bytes(self)
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class GainSettingsProperties:
'''
Cf. 3.2 Gain Settings Properties
'''
gain_settings_unit: int = 1
gain_settings_minimum: int = GAIN_SETTINGS_MIN_VALUE
gain_settings_maximum: int = GAIN_SETTINGS_MAX_VALUE
@classmethod
def from_bytes(cls, data: bytes):
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
struct.unpack('BBB', data)
)
GainSettingsProperties(
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
)
def __bytes__(self) -> bytes:
return bytes(
[
self.gain_settings_unit,
self.gain_settings_minimum,
self.gain_settings_maximum,
]
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class AudioInputControlPoint:
'''
Cf. 3.5.2 Audio Input Control Point
'''
audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
gain_settings_operand = value[2]
await self._set_gain_settings(connection, gain_settings_operand)
elif opcode == AudioInputControlPointOpCode.UNMUTE:
await self._unmute(connection)
elif opcode == AudioInputControlPointOpCode.MUTE:
change_counter_operand = value[1]
await self._mute(connection, change_counter_operand)
elif opcode == AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE:
await self._set_manual_gain_mode(connection)
elif opcode == AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE:
await self._set_automatic_gain_mode(connection)
else:
logger.error(f"OpCode value is incorrect: {opcode}")
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
async def _set_gain_settings(
self, connection: Connection, gain_settings_operand: int
) -> None:
'''Cf. 3.5.2.1 Set Gain Settings Procedure'''
gain_mode = self.audio_input_state.gain_mode
logger.error(f"set_gain_setting: gain_mode: {gain_mode}")
if not (gain_mode == GainMode.MANUAL or gain_mode == GainMode.MANUAL_ONLY):
logger.warning(
"GainMode should be either MANUAL or MANUAL_ONLY Cf Spec Audio Input Control Service 3.5.2.1"
)
return
if (
gain_settings_operand < self.gain_settings_properties.gain_settings_minimum
or gain_settings_operand
> self.gain_settings_properties.gain_settings_maximum
):
logger.error("gain_seetings value out of range")
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
if self.audio_input_state.gain_settings != gain_settings_operand:
self.audio_input_state.gain_settings = gain_settings_operand
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _unmute(self, connection: Connection):
'''Cf. 3.5.2.2 Unmute procedure'''
logger.error(f'unmute: {self.audio_input_state.mute}')
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("unmute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if mute == Mute.NOT_MUTED:
return
self.audio_input_state.mute = Mute.NOT_MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _mute(self, connection: Connection, change_counter_operand: int) -> None:
'''Cf. 3.5.5.2 Mute procedure'''
change_counter = self.audio_input_state.change_counter
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("mute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if mute == Mute.MUTED:
return
self.audio_input_state.mute = Mute.MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_manual_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.4 Set Manual Gain Mode procedure'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.MANUAL:
return
self.audio_input_state.gain_mode = GainMode.MANUAL
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_automatic_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.5 Set Automatic Gain Mode'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.AUTOMATIC:
return
self.audio_input_state.gain_mode = GainMode.AUTOMATIC
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioInputDescription:
'''
Cf. 3.6 Audio Input Description
'''
audio_input_description: str = "Bluetooth"
attribute_value: Optional[CharacteristicValue] = None
@classmethod
def from_bytes(cls, data: bytes):
return cls(audio_input_description=data.decode('utf-8'))
def __bytes__(self) -> bytes:
return self.audio_input_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return self.audio_input_description.encode('utf-8')
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_input_description = value.decode('utf-8')
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value
)
class AICSService(TemplateService):
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
def __init__(
self,
audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_type: str = "local",
audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: Optional[AudioInputDescription] = None,
):
self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state
)
self.gain_settings_properties = (
GainSettingsProperties()
if gain_settings_properties is None
else gain_settings_properties
)
self.audio_input_status = (
AudioInputStatus.ACTIVE
if audio_input_status is None
else audio_input_status
)
self.audio_input_description = (
AudioInputDescription()
if audio_input_description is None
else audio_input_description
)
self.audio_input_control_point: AudioInputControlPoint = AudioInputControlPoint(
self.audio_input_state, self.gain_settings_properties
)
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.audio_input_state.on_read),
),
encode=lambda value: bytes(value),
)
self.audio_input_state.attribute_value = (
self.audio_input_state_characteristic.value
)
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.gain_settings_properties.on_read),
)
)
self.audio_input_type_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=audio_input_type,
)
self.audio_input_status_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([self.audio_input_status]),
)
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.audio_input_control_point.on_write
),
)
)
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.audio_input_description.on_write,
read=self.audio_input_description.on_read,
),
)
)
self.audio_input_description.attribute_value = (
self.audio_input_control_point_characteristic.value
)
super().__init__(
characteristics=[
self.audio_input_state_characteristic, # type: ignore
self.gain_settings_properties_characteristic, # type: ignore
self.audio_input_type_characteristic, # type: ignore
self.audio_input_status_characteristic, # type: ignore
self.audio_input_control_point_characteristic, # type: ignore
self.audio_input_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class AICSServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = AICSService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
self.audio_input_state = DelegatedCharacteristicAdapter(
characteristic=characteristics[0], decode=AudioInputState.from_bytes
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Gain Settings Attribute Characteristic not found"
)
self.gain_settings_properties = PackedCharacteristicAdapter(
characteristics[0],
'BBB',
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Status Characteristic not found"
)
self.audio_input_status = PackedCharacteristicAdapter(
characteristics[0],
'B',
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Control Point Characteristic not found"
)
self.audio_input_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Description Characteristic not found"
)
self.audio_input_description = characteristics[0]

View File

@@ -1,739 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Audio Stream Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import logging
import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
# -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# -----------------------------------------------------------------------------
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
# -----------------------------------------------------------------------------
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = (),
sink_ase_id: Sequence[int] = (),
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

View File

@@ -1,295 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import struct
import logging
from typing import List, Optional, Callable, Union, Any
from bumble import l2cap
from bumble import utils
from bumble import gatt
from bumble import gatt_client
from bumble.core import AdvertisingData
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
_logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class DeviceCapabilities(enum.IntFlag):
IS_RIGHT = 0x01
IS_DUAL = 0x02
CSIS_SUPPORTED = 0x04
class FeatureMap(enum.IntFlag):
LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED = 0x01
class AudioType(utils.OpenIntEnum):
UNKNOWN = 0x00
RINGTONE = 0x01
PHONE_CALL = 0x02
MEDIA = 0x03
class OpCode(utils.OpenIntEnum):
START = 1
STOP = 2
STATUS = 3
class Codec(utils.OpenIntEnum):
G_722_16KHZ = 1
class SupportedCodecs(enum.IntFlag):
G_722_16KHZ = 1 << Codec.G_722_16KHZ
class PeripheralStatus(utils.OpenIntEnum):
"""Status update on the other peripheral."""
OTHER_PERIPHERAL_DISCONNECTED = 1
OTHER_PERIPHERAL_CONNECTED = 2
CONNECTION_PARAMETER_UPDATED = 3
class AudioStatus(utils.OpenIntEnum):
"""Status report field for the audio control point."""
OK = 0
UNKNOWN_COMMAND = -1
ILLEGAL_PARAMETERS = -2
# -----------------------------------------------------------------------------
class AshaService(gatt.TemplateService):
UUID = gatt.GATT_ASHA_SERVICE
audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None
volume: Optional[int] = None
other_state: Optional[int] = None
connection: Optional[Connection] = None
def __init__(
self,
capability: int,
hisyncid: Union[List[int], bytes],
device: Device,
psm: int = 0,
audio_sink: Optional[Callable[[bytes], Any]] = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01,
render_delay_milliseconds: int = 0,
supported_codecs: int = SupportedCodecs.G_722_16KHZ,
) -> None:
if len(hisyncid) != 8:
_logger.warning('HiSyncId should have a length of 8, got %d', len(hisyncid))
self.hisyncid = bytes(hisyncid)
self.capability = capability
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
self.audio_sink = audio_sink
self.protocol_version = protocol_version
self.read_only_properties_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack(
"<BB8sBH2sH",
protocol_version,
capability,
self.hisyncid,
feature_map,
render_delay_milliseconds,
b'\x00\x00',
supported_codecs,
),
)
self.audio_control_point_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
)
self.audio_status_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.READABLE,
bytes([AudioStatus.OK]),
)
self.volume_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_volume_write),
)
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=self._on_connection,
).psm
self.le_psm_out_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self) -> bytes:
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(gatt.GATT_ASHA_SERVICE)
+ bytes([self.protocol_version, self.capability])
+ self.hisyncid[:4],
),
]
)
)
# Handler for audio control commands
async def _on_audio_control_point_write(
self, connection: Optional[Connection], value: bytes
) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == OpCode.START:
# Start
self.active_codec = Codec(value[1])
self.audio_type = AudioType(value[2])
self.volume = value[3]
self.other_state = value[4]
_logger.debug(
f'### START: codec={self.active_codec.name}, '
f'audio_type={self.audio_type.name}, '
f'volume={self.volume}, '
f'other_state={self.other_state}'
)
self.emit('started')
elif opcode == OpCode.STOP:
_logger.debug('### STOP')
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('stopped')
elif opcode == OpCode.STATUS:
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
if self.connection is None and connection:
self.connection = connection
def on_disconnection(_reason) -> None:
self.connection = None
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('disconnected')
connection.once('disconnection', on_disconnection)
# OPCODE_STATUS does not need audio status point update
if opcode != OpCode.STATUS:
await self.device.notify_subscribers(
self.audio_status_characteristic, force=True
)
# Handler for volume control
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0]
self.emit('volume_changed')
# Register an L2CAP CoC server
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
def on_data(data: bytes) -> None:
if self.audio_sink: # pylint: disable=not-callable
self.audio_sink(data)
channel.sink = on_data
# -----------------------------------------------------------------------------
class AshaServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AshaService
read_only_properties_characteristic: gatt_client.CharacteristicProxy
audio_control_point_characteristic: gatt_client.CharacteristicProxy
audio_status_point_characteristic: gatt_client.CharacteristicProxy
volume_characteristic: gatt_client.CharacteristicProxy
psm_characteristic: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
for uuid, attribute_name in (
(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
'read_only_properties_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
'audio_control_point_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
'audio_status_point_characteristic',
),
(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
'volume_characteristic',
),
(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
'psm_characteristic',
),
):
if not (
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
):
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
setattr(self, attribute_name, characteristics[0])

View File

@@ -0,0 +1,193 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from bumble import l2cap
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -24,12 +24,15 @@ import enum
import struct
import functools
import logging
from typing import List
from typing import Optional, List, Union, Type, Dict, Any, Tuple
from typing_extensions import Self
from bumble import core
from bumble import colors
from bumble import device
from bumble import hci
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from bumble.profiles import le_audio
@@ -248,6 +251,231 @@ class AnnouncementType(utils.OpenIntEnum):
TARGETED = 0x01
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
@dataclasses.dataclass
class UnicastServerAdvertisingData:
"""Advertising Data for ASCS."""
@@ -455,6 +683,54 @@ class CodecSpecificConfiguration:
)
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
@dataclasses.dataclass
class BroadcastAudioAnnouncement:
broadcast_id: int
@@ -546,3 +822,603 @@ class BasicAudioAnnouncement:
)
return cls(presentation_delay, subgroups)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = [],
sink_ase_id: Sequence[int] = [],
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

View File

@@ -1,440 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Broadcast Audio Scan Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, List, Optional, Sequence
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
from bumble import utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
OPCODE_NOT_SUPPORTED = 0x80
INVALID_SOURCE_ID = 0x81
# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
return bytes([len(subgroups)]) + b"".join(
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
+ subgroup.metadata
for subgroup in subgroups
)
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
num_subgroups = data[0]
offset = 1
subgroups = []
for _ in range(num_subgroups):
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
metadata_length = data[offset + 4]
metadata = data[offset + 5 : offset + 5 + metadata_length]
offset += 5 + metadata_length
subgroups.append(SubgroupInfo(bis_sync, metadata))
return subgroups
# -----------------------------------------------------------------------------
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
@dataclasses.dataclass
class SubgroupInfo:
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
bis_sync: int
metadata: bytes
class ControlPointOperation:
class OpCode(utils.OpenIntEnum):
REMOTE_SCAN_STOPPED = 0x00
REMOTE_SCAN_STARTED = 0x01
ADD_SOURCE = 0x02
MODIFY_SOURCE = 0x03
SET_BROADCAST_CODE = 0x04
REMOVE_SOURCE = 0x05
op_code: OpCode
parameters: bytes
@classmethod
def from_bytes(cls, data: bytes) -> ControlPointOperation:
op_code = data[0]
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
return RemoteScanStoppedOperation()
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
return RemoteScanStartedOperation()
if op_code == cls.OpCode.ADD_SOURCE:
return AddSourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.MODIFY_SOURCE:
return ModifySourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.SET_BROADCAST_CODE:
return SetBroadcastCodeOperation.from_parameters(data[1:])
if op_code == cls.OpCode.REMOVE_SOURCE:
return RemoveSourceOperation.from_parameters(data[1:])
raise core.InvalidArgumentError("invalid op code")
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
self.op_code = op_code
self.parameters = parameters
def __bytes__(self) -> bytes:
return bytes([self.op_code]) + self.parameters
class RemoteScanStoppedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
class RemoteScanStartedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
class AddSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
instance.parameters = parameters
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
parameters, 1
)[1]
instance.advertising_sid = parameters[7]
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
instance.subgroups = decode_subgroups(parameters[14:])
return instance
def __init__(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.ADD_SOURCE,
struct.pack(
"<B6sB3sBH",
advertiser_address.address_type,
bytes(advertiser_address),
advertising_sid,
broadcast_id.to_bytes(3, "little"),
pa_sync,
pa_interval,
)
+ encode_subgroups(subgroups),
)
self.advertiser_address = advertiser_address
self.advertising_sid = advertising_sid
self.broadcast_id = broadcast_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class ModifySourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
instance.subgroups = decode_subgroups(parameters[4:])
return instance
def __init__(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.MODIFY_SOURCE,
struct.pack("<BBH", source_id, pa_sync, pa_interval)
+ encode_subgroups(subgroups),
)
self.source_id = source_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class SetBroadcastCodeOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.broadcast_code = parameters[1:17]
return instance
def __init__(
self,
source_id: int,
broadcast_code: bytes,
) -> None:
super().__init__(
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
bytes([source_id]) + broadcast_code,
)
self.source_id = source_id
self.broadcast_code = broadcast_code
if len(self.broadcast_code) != 16:
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
class RemoveSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
return instance
def __init__(self, source_id: int) -> None:
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
self.source_id = source_id
@dataclasses.dataclass
class BroadcastReceiveState:
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
NOT_SYNCHRONIZED_TO_PA = 0x00
SYNCINFO_REQUEST = 0x01
SYNCHRONIZED_TO_PA = 0x02
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
NO_PAST = 0x04
class BigEncryption(utils.OpenIntEnum):
NOT_ENCRYPTED = 0x00
BROADCAST_CODE_REQUIRED = 0x01
DECRYPTING = 0x02
BAD_CODE = 0x03
source_id: int
source_address: hci.Address
source_adv_sid: int
broadcast_id: int
pa_sync_state: PeriodicAdvertisingSyncState
big_encryption: BigEncryption
bad_code: bytes
subgroups: List[SubgroupInfo]
@classmethod
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
if not data:
return None
source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8]
broadcast_id = int.from_bytes(data[9:12], "little")
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
big_encryption = cls.BigEncryption(data[13])
if big_encryption == cls.BigEncryption.BAD_CODE:
bad_code = data[14:30]
subgroups = decode_subgroups(data[30:])
else:
bad_code = b""
subgroups = decode_subgroups(data[14:])
return cls(
source_id,
source_address,
source_adv_sid,
broadcast_id,
pa_sync_state,
big_encryption,
bad_code,
subgroups,
)
def __bytes__(self) -> bytes:
return (
struct.pack(
"<BB6sB3sBB",
self.source_id,
self.source_address.address_type,
bytes(self.source_address),
self.source_adv_sid,
self.broadcast_id.to_bytes(3, "little"),
self.pa_sync_state,
self.big_encryption,
)
+ self.bad_code
+ encode_subgroups(self.subgroups)
)
# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
def __init__(self):
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
write=self.on_broadcast_audio_scan_control_point_write
),
)
self.broadcast_receive_state_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
b"12", # TEST
)
super().__init__([self.battery_level_characteristic])
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
) -> None:
pass
# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BroadcastAudioScanService
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [
gatt.DelegatedCharacteristicAdapter(
characteristic, decode=BroadcastReceiveState.from_bytes
)
for characteristic in characteristics
]
async def send_control_point_operation(
self, operation: ControlPointOperation
) -> None:
await self.broadcast_audio_scan_control_point.write_value(
bytes(operation), with_response=True
)
async def remote_scan_started(self) -> None:
await self.send_control_point_operation(RemoteScanStartedOperation())
async def remote_scan_stopped(self) -> None:
await self.send_control_point_operation(RemoteScanStoppedOperation())
async def add_source(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
AddSourceOperation(
advertiser_address,
advertising_sid,
broadcast_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def modify_source(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
ModifySourceOperation(
source_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def remove_source(self, source_id: int) -> None:
await self.send_control_point_operation(RemoveSourceOperation(source_id))

View File

@@ -1,110 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic Access Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Tuple, Union
from bumble.core import Appearance
from bumble.gatt import (
TemplateService,
Characteristic,
CharacteristicAdapter,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(TemplateService):
UUID = GATT_GENERIC_ACCESS_SERVICE
def __init__(
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
self.appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', appearance_int),
)
super().__init__(
[self.device_name_characteristic, self.appearance_characteristic]
)
# -----------------------------------------------------------------------------
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicAdapter]
appearance: Optional[DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC
):
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
else:
self.device_name = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_APPEARANCE_CHARACTERISTIC
):
self.appearance = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Appearance.from_int(
struct.unpack_from('<H', value, 0)[0],
),
)
else:
self.appearance = None

View File

@@ -1,674 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import functools
from bumble import att, gatt, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Device, Connection
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.hci import Address
from dataclasses import dataclass, field
import logging
from typing import Any, Dict, List, Optional, Set, Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(OpenIntEnum):
'''See Hearing Access Service 2.4. Attribute Profile error codes.'''
INVALID_OPCODE = 0x80
WRITE_NAME_NOT_ALLOWED = 0x81
PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
PRESET_OPERATION_NOT_POSSIBLE = 0x83
INVALID_PARAMETERS_LENGTH = 0x84
class HearingAidType(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
BINAURAL_HEARING_AID = 0b00
MONAURAL_HEARING_AID = 0b01
BANDED_HEARING_AID = 0b10
class PresetSynchronizationSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
class IndependentPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
IDENTICAL_PRESET_RECORD = 0b0
DIFFERENT_PRESET_RECORD = 0b1
class DynamicPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
PRESET_RECORDS_MAY_CHANGE = 0b1
class WritablePresetsSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
class HearingAidPresetControlPointOpcode(OpenIntEnum):
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
# fmt: off
READ_PRESETS_REQUEST = 0x01
READ_PRESET_RESPONSE = 0x02
PRESET_CHANGED = 0x03
WRITE_PRESET_NAME = 0x04
SET_ACTIVE_PRESET = 0x05
SET_NEXT_PRESET = 0x06
SET_PREVIOUS_PRESET = 0x07
SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY = 0x08
SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY = 0x09
SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
@dataclass
class HearingAidFeatures:
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
hearing_aid_type: HearingAidType
preset_synchronization_support: PresetSynchronizationSupport
independent_presets: IndependentPresets
dynamic_presets: DynamicPresets
writable_presets_support: WritablePresetsSupport
def __bytes__(self) -> bytes:
return bytes(
[
(self.hearing_aid_type << 0)
| (self.preset_synchronization_support << 2)
| (self.independent_presets << 3)
| (self.dynamic_presets << 4)
| (self.writable_presets_support << 5)
]
)
def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
return HearingAidFeatures(
HearingAidType(data & 0b11),
PresetSynchronizationSupport(data >> 2 & 0b1),
IndependentPresets(data >> 3 & 0b1),
DynamicPresets(data >> 4 & 0b1),
WritablePresetsSupport(data >> 5 & 0b1),
)
@dataclass
class PresetChangedOperation:
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
class ChangeId(OpenIntEnum):
# fmt: off
GENERIC_UPDATE = 0x00
PRESET_RECORD_DELETED = 0x01
PRESET_RECORD_AVAILABLE = 0x02
PRESET_RECORD_UNAVAILABLE = 0x03
@dataclass
class Generic:
prev_index: int
preset_record: PresetRecord
def __bytes__(self) -> bytes:
return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId
additional_parameters: Union[Generic, int]
def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
additional_parameters_bytes = bytes(self.additional_parameters)
else:
additional_parameters_bytes = bytes([self.additional_parameters])
return (
bytes(
[
HearingAidPresetControlPointOpcode.PRESET_CHANGED,
self.change_id,
is_last,
]
)
+ additional_parameters_bytes
)
class PresetChangedOperationDeleted(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
self.additional_parameters = index
class PresetChangedOperationAvailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
self.additional_parameters = index
class PresetChangedOperationUnavailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
self.additional_parameters = index
@dataclass
class PresetRecord:
'''See Hearing Access Service 2.8. Preset record.'''
@dataclass
class Property:
class Writable(OpenIntEnum):
CANNOT_BE_WRITTEN = 0b0
CAN_BE_WRITTEN = 0b1
class IsAvailable(OpenIntEnum):
IS_UNAVAILABLE = 0b0
IS_AVAILABLE = 0b1
writable: Writable = Writable.CAN_BE_WRITTEN
is_available: IsAvailable = IsAvailable.IS_AVAILABLE
def __bytes__(self) -> bytes:
return bytes([self.writable | (self.is_available << 1)])
index: int
name: str
properties: Property = field(default_factory=Property)
def __bytes__(self) -> bytes:
return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8')
def is_available(self) -> bool:
return (
self.properties.is_available
== PresetRecord.Property.IsAvailable.IS_AVAILABLE
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
UUID = gatt.GATT_HEARING_ACCESS_SERVICE
hearing_aid_features_characteristic: gatt.Characteristic
hearing_aid_preset_control_point: gatt.Characteristic
active_preset_index_characteristic: gatt.Characteristic
active_preset_index: int
active_preset_index_per_device: Dict[Address, int]
device: Device
server_features: HearingAidFeatures
preset_records: Dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool
preset_changed_operations_history_per_device: Dict[
Address, List[PresetChangedOperation]
]
# Keep an updated list of connected client to send notification to
currently_connected_clients: Set[Connection]
def __init__(
self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
) -> None:
self.active_preset_index_per_device = {}
self.read_presets_request_in_progress = False
self.preset_changed_operations_history_per_device = {}
self.currently_connected_clients = set()
self.device = device
self.server_features = features
if len(presets) < 1:
raise InvalidArgumentError(f'Invalid presets: {presets}')
self.preset_records = {}
for p in presets:
if len(p.name.encode()) < 1 or len(p.name.encode()) > 40:
raise InvalidArgumentError(f'Invalid name: {p.name}')
self.preset_records[p.index] = p
# associate the lowest index as the current active preset at startup
self.active_preset_index = sorted(self.preset_records.keys())[0]
@device.on('connection') # type: ignore
def on_connection(connection: Connection) -> None:
@connection.on('disconnection') # type: ignore
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
@connection.on('pairing') # type: ignore
def on_pairing(*_: Any) -> None:
self.on_incoming_paired_connection(connection)
if connection.peer_resolvable_address:
self.on_incoming_paired_connection(connection)
self.hearing_aid_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes(self.server_features),
)
self.hearing_aid_preset_control_point = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.INDICATE
),
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(
write=self._on_write_hearing_aid_preset_control_point
),
)
self.active_preset_index_characteristic = gatt.Characteristic(
uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
)
super().__init__(
[
self.hearing_aid_features_characteristic,
self.hearing_aid_preset_control_point,
self.active_preset_index_characteristic,
]
)
def on_incoming_paired_connection(self, connection: Connection):
'''Setup initial operations to handle a remote bonded HAP device'''
# TODO Should we filter on HAP device only ?
self.currently_connected_clients.add(connection)
if (
connection.peer_address
not in self.preset_changed_operations_history_per_device
):
self.preset_changed_operations_history_per_device[
connection.peer_address
] = []
return
async def on_connection_async() -> None:
# Send all the PresetChangedOperation that occur when not connected
await self._preset_changed_operation(connection)
# Update the active preset index if needed
await self.notify_active_preset_for_connection(connection)
connection.abort_on('disconnection', on_connection_async())
def _on_read_active_preset_index(
self, __connection__: Optional[Connection]
) -> bytes:
return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded
def on_forget(self, addr: Address) -> None:
self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point(
self, connection: Optional[Connection], value: bytes
):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value)
async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}')
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
self.read_presets_request_in_progress = True
start_index = value[1]
if start_index == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
num_presets = value[2]
if num_presets == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
# Sending `num_presets` presets ordered by increasing index field, starting from start_index
presets = [
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
AsyncRunner.spawn(self._read_preset_response(connection, presets))
async def _read_preset_response(
self, connection: Connection, presets: List[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=bytes(
[
HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
i == len(presets) - 1,
]
)
+ bytes(preset),
)
finally:
# indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
self.read_presets_request_in_progress = False
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot set active preset as unavailable')
preset = self.preset_records[index]
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
op_list = self.preset_changed_operations_history_per_device.get(
connection.peer_address, []
)
# Notification will be sent in index order
def get_op_index(op: PresetChangedOperation) -> int:
if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
return op.additional_parameters.prev_index
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
try:
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=op_list[0].to_bytes(len(op_list) == 1),
)
# Remove item once sent, and keep the non sent item in the list
op_list.pop(0)
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(
self, connection: Optional[Connection], value: bytes
):
assert connection
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.writable
== PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
):
raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)
name = value[2:].decode('utf-8')
if not name or len(name) > 40:
raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)
preset.name = name
await self.generic_update(
PresetChangedOperation(
PresetChangedOperation.ChangeId.GENERIC_UPDATE,
PresetChangedOperation.Generic(index, preset),
)
)
async def notify_active_preset_for_connection(self, connection: Connection) -> None:
if (
self.active_preset_index_per_device.get(connection.peer_address, 0x00)
== self.active_preset_index
):
# Nothing to do, peer is already updated
return
await connection.device.notify_subscriber(
connection,
attribute=self.active_preset_index_characteristic,
value=bytes([self.active_preset_index]),
)
self.active_preset_index_per_device[connection.peer_address] = (
self.active_preset_index
)
async def notify_active_preset(self) -> None:
for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection)
async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.is_available
!= PresetRecord.Property.IsAvailable.IS_AVAILABLE
):
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if index == self.active_preset_index:
# Already at correct value
return
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(
self, connection: Optional[Connection], value: bytes
):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
'''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
if not first_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, True)
async def _on_set_active_preset_synchronized_locally(
self, connection: Optional[Connection], value: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_active_preset(connection, value)
# TODO (low priority) inform other server of the change
async def _on_set_next_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, False)
# TODO (low priority) inform other server of the change
async def _on_set_previous_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, True)
# TODO (low priority) inform other server of the change
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.server_features = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
)[0],
'B',
)
self.hearing_aid_preset_control_point = (
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
)[0]
)
self.active_preset_index = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
)[0],
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
)

View File

@@ -1,210 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Published Audio Capabilities Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import Optional, Sequence, Union
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
from bumble.profiles import le_audio
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]

View File

@@ -1,89 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Telephony and Media Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import logging
import struct
from bumble.gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Role(enum.IntFlag):
CALL_GATEWAY = 1 << 0
CALL_TERMINAL = 1 << 1
UNICAST_MEDIA_SENDER = 1 << 2
UNICAST_MEDIA_RECEIVER = 1 << 3
BROADCAST_MEDIA_SENDER = 1 << 4
BROADCAST_MEDIA_RECEIVER = 1 << 5
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioService(TemplateService):
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
def __init__(self, role: Role):
self.role_characteristic = Characteristic(
GATT_TMAP_ROLE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', int(role)),
)
super().__init__([self.role_characteristic])
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = TelephonyAndMediaAudioService
role: DelegatedCharacteristicAdapter
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0],
),
)

View File

@@ -24,7 +24,7 @@ from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional, Sequence
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
@@ -88,7 +88,6 @@ class VolumeControlService(gatt.TemplateService):
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
included_services: Sequence[gatt.Service] = (),
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
@@ -118,12 +117,11 @@ class VolumeControlService(gatt.TemplateService):
)
super().__init__(
characteristics=[
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
],
included_services=list(included_services),
]
)
@property

View File

@@ -1,110 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number & 0xFFFF
self.timestamp = timestamp & 0xFFFFFFFF
self.timestamp_seconds = 0.0
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack(
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)

View File

@@ -764,9 +764,7 @@ class Session:
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
self.oob_data_flag = (
1 if pairing_config.oob and pairing_config.oob.peer_data else 0
)
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses
self_address = connection.self_resolvable_address or connection.self_address
@@ -1016,10 +1014,8 @@ class Session:
self.send_command(response)
def send_pairing_confirm_command(self) -> None:
if self.pairing_method != PairingMethod.OOB:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
@@ -1082,19 +1078,11 @@ class Session:
)
def send_identity_address_command(self) -> None:
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
identity_address = self.manager.device.public_address
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
identity_address = self.manager.device.static_address
else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
identity_address = {
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
@@ -1739,6 +1727,7 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
@@ -1746,22 +1735,6 @@ class Session:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
elif self.pairing_method == PairingMethod.OOB:
if self.is_initiator:
if self.peer_oob_data:
rb = self.peer_oob_data.r
ra = self.r
else:
rb = bytes(16)
ra = self.r
else:
if self.peer_oob_data:
ra = self.peer_oob_data.r
rb = self.r
else:
ra = bytes(16)
rb = self.r
else:
return

View File

@@ -70,9 +70,6 @@ def get_ini_dir() -> Optional[pathlib.Path]:
elif sys.platform == 'linux':
if xdg_runtime_dir := os.environ.get('XDG_RUNTIME_DIR', None):
return pathlib.Path(xdg_runtime_dir)
tmpdir = os.environ.get('TMPDIR', '/tmp')
if pathlib.Path(tmpdir).is_dir():
return pathlib.Path(tmpdir)
elif sys.platform == 'win32':
if local_app_data_dir := os.environ.get('LOCALAPPDATA', None):
return pathlib.Path(local_app_data_dir) / 'Temp'

View File

@@ -248,28 +248,26 @@ class AsyncPipeSink:
# -----------------------------------------------------------------------------
class BaseSource:
class ParserSource:
"""
Base class designed to be subclassed by transport-specific source classes
"""
terminated: asyncio.Future[None]
sink: Optional[TransportSink]
parser: PacketParser
def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
self.sink = None
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
self.parser.set_packet_sink(sink)
def on_transport_lost(self) -> None:
if not self.terminated.done():
self.terminated.set_result(None)
if self.sink:
if hasattr(self.sink, 'on_transport_lost'):
self.sink.on_transport_lost()
self.terminated.set_result(None)
if self.parser.sink:
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost()
async def wait_for_termination(self) -> None:
"""
@@ -282,23 +280,6 @@ class BaseSource:
pass
# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
"""
Base class for sources that use an HCI parser.
"""
parser: PacketParser
def __init__(self) -> None:
super().__init__()
self.parser = PacketParser()
def set_packet_sink(self, sink: TransportSink) -> None:
super().set_packet_sink(sink)
self.parser.set_packet_sink(sink)
# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None:

View File

@@ -23,7 +23,7 @@ import time
import usb.core
import usb.util
from typing import Optional, Set
from typing import Optional
from usb.core import Device as UsbDevice
from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
@@ -46,11 +46,6 @@ RESET_DELAY = 3
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Global
# -----------------------------------------------------------------------------
devices_in_use: Set[int] = set()
# -----------------------------------------------------------------------------
async def open_pyusb_transport(spec: str) -> Transport:
@@ -222,8 +217,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
await self.source.stop()
await self.sink.stop()
usb.util.release_interface(self.device, 0)
if devices_in_use and device.address in devices_in_use:
devices_in_use.remove(device.address)
usb_find = usb.core.find
try:
@@ -240,18 +233,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
spec = spec[1:]
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = None
devices = usb_find(
find_all=True, idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
for d in devices:
if d.address in devices_in_use:
continue
device = d
devices_in_use.add(d.address)
break
if device is None:
raise ValueError('device already in use')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
elif '-' in spec:
def device_path(device):

View File

@@ -24,7 +24,7 @@ import platform
import usb1
from bumble.transport.common import Transport, BaseSource, TransportInitError
from bumble.transport.common import Transport, ParserSource, TransportInitError
from bumble import hci
from bumble.colors import color
@@ -139,7 +139,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.packets.put_nowait(packet)
def transfer_callback(self, transfer):
self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
self.acl_out_transfer_ready.release()
status = transfer.getStatus()
# pylint: disable=no-member
@@ -208,7 +208,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError:
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, BaseSource):
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, acl_in, events_in):
super().__init__()
self.device = device
@@ -285,13 +285,7 @@ async def open_usb_transport(spec: str) -> Transport:
packet = await self.queue.get()
except asyncio.CancelledError:
return
if self.sink:
try:
self.sink.on_packet(packet)
except Exception as error:
logger.exception(
color(f'!!! Exception in sink.on_packet: {error}', 'red')
)
self.parser.feed_data(packet)
def close(self):
self.closed = True

View File

@@ -1,7 +1,7 @@
EXAMPLES
========
The project includes a few simple example applications to illustrate some of the ways the library APIs can be used.
The project includes a few simple example applications the illustrate some of the ways the library APIs can be used.
These examples include:
## `battery_service.py`
@@ -25,9 +25,6 @@ An app that implements a virtual Bluetooth speaker that can receive audio.
## `run_advertiser.py`
An app that runs a simple device that just advertises (BLE).
## `run_cig_setup.py`
An app that creates a simple CIG containing two CISes. **Note**: If using the example config file (e.g. `device1.json`), the `address` needs to be removed, so that the devices are given different random addresses.
## `run_classic_connect.py`
An app that connects to a Bluetooth Classic device and prints its services.
@@ -45,9 +42,6 @@ An app that connected to a device (BLE) and encrypts the connection.
## `run_controller.py`
Creates two linked controllers, attaches one to a transport, and the other to a local host with a GATT server application. This can be used, for example, to attach a virtual controller to a native stack, like BlueZ on Linux, and use the native tools, like `bluetoothctl`, to scan and connect to the GATT server included in the example.
## `run_csis_servers.py`
Runs CSIS servers on two devices to form a Coordinated Set. **Note**: If using the example config file (e.g. `device1.json`), the `address` needs to be removed, so that the devices are given different random addresses.
## `run_gatt_client_and_server.py`
Runs a local GATT server and GATT client, connected to each other. The GATT client discovers and logs all the services and characteristics exposed by the GATT server

View File

@@ -1,95 +0,0 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble ASHA Sink</span>
</div>
</nav>
<br>
<div class="container">
<div class="row">
<div class="col-auto">
<button id="connect-audio" class="btn btn-danger" onclick="connectAudio()">Connect Audio</button>
</div>
</div>
<hr>
<div class="row">
<div class="col-4">
<label class="form-label">Browser Gain</label>
<input type="range" class="form-range" id="browser-gain" min="0" max="2" value="1" step="0.1"
onchange="setGain()">
</div>
</div>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let atResponseInput = document.getElementById("at_response")
let gainInput = document.getElementById('browser-gain')
let log = document.getElementById("log")
let socket = new WebSocket('ws://localhost:8888');
let sampleRate = 0;
let player;
socket.binaryType = "arraybuffer";
socket.onopen = _ => {
log.textContent += 'SOCKET OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'SOCKET CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'SOCKET ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
log.textContent += `<-- ${event.data}\n`
} else {
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 1,
sampleRate: 16000,
flushTime: 20,
});
player.volume(gainInput.value);
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
function setGain() {
if (player != null) {
player.volume(gainInput.value);
}
}
</script>
</div>
</body>
</html>

View File

@@ -1,6 +1,5 @@
{
"name": "Bumble Aid Left",
"address": "F1:F2:F3:F4:F5:F6",
"identity_address_type": 1,
"keystore": "JsonKeyStore"
}
}

View File

@@ -1,6 +1,5 @@
{
"name": "Bumble Aid Right",
"address": "F7:F8:F9:FA:FB:FC",
"identity_address_type": 1,
"keystore": "JsonKeyStore"
}
}

View File

@@ -61,23 +61,20 @@ def codec_capabilities():
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_4
| SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| SbcMediaCodecInformation.AllocationMethod.SNR,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),

View File

@@ -33,6 +33,8 @@ from bumble.avdtp import (
Listener,
)
from bumble.a2dp import (
SBC_JOINT_STEREO_CHANNEL_MODE,
SBC_LOUDNESS_ALLOCATION_METHOD,
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation,
@@ -57,12 +59,12 @@ def codec_capabilities():
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_44100,
channel_mode=SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
media_codec_information=SbcMediaCodecInformation.from_discrete_values(
sampling_frequency=44100,
channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE,
block_length=16,
subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
@@ -71,9 +73,11 @@ def codec_capabilities():
# -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu)
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(codec_capabilities(), packet_pump)
protocol.add_source(packet_source.codec_capabilities, packet_pump)
# -----------------------------------------------------------------------------
@@ -93,9 +97,11 @@ async def stream_packets(read_function, protocol):
print(f'### Selected sink: {sink.seid}')
# Stream the packets
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu)
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(codec_capabilities(), packet_pump)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink)
await stream.start()
await asyncio.sleep(5)

View File

@@ -16,104 +16,192 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import sys
import os
import logging
import websockets
from typing import Optional
from bumble import decoder
from bumble import gatt
from bumble import l2cap
from bumble.core import AdvertisingData
from bumble.device import Device, AdvertisingParameters
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.profiles import asha
ws_connection: Optional[websockets.WebSocketServerProtocol] = None
g722_decoder = decoder.G722Decoder()
from bumble.core import UUID
from bumble.gatt import Service, Characteristic, CharacteristicValue
async def ws_server(ws_client: websockets.WebSocketServerProtocol, path: str):
del path
global ws_connection
ws_connection = ws_client
async for message in ws_client:
print(message)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID(
'6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties'
)
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID(
'f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint'
)
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID(
'38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus'
)
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID(
'2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT'
)
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) != 3:
print('Usage: python run_asha_sink.py <device-config> <transport-spec>')
print('example: python run_asha_sink.py device1.json usb:0')
if len(sys.argv) != 4:
print(
'Usage: python run_asha_sink.py <device-config> <transport-spec> '
'<audio-file>'
)
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
def on_audio_packet(packet: bytes) -> None:
global ws_connection
if ws_connection:
offset = 1
while offset < len(packet):
pcm_data = g722_decoder.decode_frame(packet[offset : offset + 80])
offset += 80
asyncio.get_running_loop().create_task(ws_connection.send(pcm_data))
else:
logging.info("No active client")
# Handler for audio control commands
def on_audio_control_point_write(_connection, value):
print('--- AUDIO CONTROL POINT Write:', value.hex())
opcode = value[0]
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(
f'### START: codec={value[1]}, audio_type={audio_type}, '
f'volume={value[3]}, otherstate={value[4]}'
)
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
asha_service = asha.AshaService(
capability=0,
hisyncid=b'\x01\x02\x03\x04\x05\x06\x07\x08',
device=device,
audio_sink=on_audio_packet,
# Respond with a status
asyncio.create_task(
device.notify_subscribers(audio_status_characteristic, force=True)
)
# Handler for volume control
def on_volume_write(_connection, value):
print('--- VOLUME Write:', value[0])
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
print('<<< Voice data received:', data.hex())
audio_out.write(data)
channel.sink = on_data
server = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=8), handler=on_coc
)
print(f'### LE_PSM_OUT = {server.psm}')
# Add the ASHA service to the GATT server
read_only_properties_characteristic = Characteristic(
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01,
0x02,
0x03,
0x04,
0x05,
0x06,
0x07,
0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00,
0x00, # Render Delay
0x00,
0x00, # RFU
0x02,
0x00, # Codec IDs [G.722 at 16 kHz]
]
),
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', server.psm),
)
device.add_service(
Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic,
],
)
)
device.add_service(asha_service)
# Set the advertising data
advertising_data = (
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device.name, 'utf-8'),
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(ASHA_SERVICE)
+ bytes(
[
0x01, # Protocol Version
0x00, # Capability
0x01,
0x02,
0x03,
0x04, # Truncated HiSyncID
]
),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(gatt.GATT_ASHA_SERVICE),
),
]
)
),
]
)
+ asha_service.get_advertising_data()
)
# Go!
await device.power_on()
await device.create_advertising_set(
auto_restart=True,
advertising_data=advertising_data,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await device.start_advertising(auto_restart=True)
await websockets.serve(ws_server, port=8888)
await hci_transport.source.terminated
await hci_transport.source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(
level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper(),
format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -60,23 +60,20 @@ def codec_capabilities():
return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
media_codec_information=a2dp.SbcMediaCodecInformation(
sampling_frequency=a2dp.SbcMediaCodecInformation.SamplingFrequency.SF_48000
| a2dp.SbcMediaCodecInformation.SamplingFrequency.SF_44100
| a2dp.SbcMediaCodecInformation.SamplingFrequency.SF_32000
| a2dp.SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=a2dp.SbcMediaCodecInformation.ChannelMode.MONO
| a2dp.SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| a2dp.SbcMediaCodecInformation.ChannelMode.STEREO
| a2dp.SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=a2dp.SbcMediaCodecInformation.BlockLength.BL_4
| a2dp.SbcMediaCodecInformation.BlockLength.BL_8
| a2dp.SbcMediaCodecInformation.BlockLength.BL_12
| a2dp.SbcMediaCodecInformation.BlockLength.BL_16,
subbands=a2dp.SbcMediaCodecInformation.Subbands.S_4
| a2dp.SbcMediaCodecInformation.Subbands.S_8,
allocation_method=a2dp.SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| a2dp.SbcMediaCodecInformation.AllocationMethod.SNR,
media_codec_information=a2dp.SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
a2dp.SBC_MONO_CHANNEL_MODE,
a2dp.SBC_DUAL_CHANNEL_MODE,
a2dp.SBC_STEREO_CHANNEL_MODE,
a2dp.SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
a2dp.SBC_LOUDNESS_ALLOCATION_METHOD,
a2dp.SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),

View File

@@ -36,10 +36,13 @@ from bumble.transport import open_transport_or_link
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_cig_setup.py <config-file> '
'Usage: run_cig_setup.py <config-file>'
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
)
print('example: run_cig_setup.py device1.json hci-socket:0 hci-socket:1')
print(
'example: run_cig_setup.py device1.json'
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
)
return
print('<<< connecting to HCI...')
@@ -62,18 +65,18 @@ async def main() -> None:
advertising_set = await devices[0].create_advertising_set()
connection = await devices[1].connect(
devices[0].random_address, own_address_type=OwnAddressType.RANDOM
devices[0].public_address, own_address_type=OwnAddressType.PUBLIC
)
cid_ids = [2, 3]
cis_handles = await devices[1].setup_cig(
cig_id=1,
cis_id=cid_ids,
sdu_interval=(10000, 255),
sdu_interval=(10000, 0),
framing=0,
max_sdu=(120, 0),
retransmission_number=13,
max_transport_latency=(100, 5),
max_transport_latency=(100, 0),
)
def on_cis_request(

View File

@@ -38,10 +38,13 @@ from bumble.transport import open_transport_or_link
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_csis_servers.py <config-file> '
'Usage: run_cig_setup.py <config-file>'
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
)
print('example: run_csis_servers.py device1.json ' 'hci-socket:0 hci-socket:1')
print(
'example: run_cig_setup.py device1.json'
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
)
return
print('<<< connecting to HCI...')

View File

@@ -1,107 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble import att
from bumble.profiles.hap import (
HearingAccessService,
HearingAidFeatures,
HearingAidType,
PresetSynchronizationSupport,
IndependentPresets,
DynamicPresets,
WritablePresetsSupport,
PresetRecord,
)
from bumble.transport import open_transport_or_link
server_features = HearingAidFeatures(
HearingAidType.MONAURAL_HEARING_AID,
PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED,
IndependentPresets.IDENTICAL_PRESET_RECORD,
DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE,
WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED,
)
foo_preset = PresetRecord(1, "foo preset")
bar_preset = PresetRecord(50, "bar preset")
foobar_preset = PresetRecord(5, "foobar preset")
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_hap_server.py <config-file> <transport-spec-for-device>')
print('example: run_hap_server.py device1.json pty:hci_pty')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
hap = HearingAccessService(
device, server_features, [foo_preset, bar_preset, foobar_preset]
)
device.add_service(hap)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble HearingAccessService', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(HearingAccessService.UUID),
),
]
)
)
await device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
)
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -21,7 +21,7 @@ import os
import logging
import json
import websockets
import struct
from bumble.colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
@@ -30,7 +30,9 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID,
UUID,
)
from bumble.hci import Address
from bumble.hid import (
Device as HID_Device,
HID_CONTROL_PSM,
@@ -38,17 +40,20 @@ from bumble.hid import (
Message,
)
from bumble.sdp import (
Client as SDP_Client,
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices
@@ -425,7 +430,7 @@ deviceData = DeviceData()
# -----------------------------------------------------------------------------
async def keyboard_device(hid_device: HID_Device):
async def keyboard_device(hid_device):
# Start a Websocket server to receive events from a web page
async def serve(websocket, _path):
@@ -471,9 +476,9 @@ async def keyboard_device(hid_device: HID_Device):
# limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y))
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
">bb", x, y
)
x_cord = x.to_bytes(signed=True)
y_cord = y.to_bytes(signed=True)
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -510,9 +515,7 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
retValue = hid_device.GetSetStatus()
print(
"GET_REPORT report_id: "
@@ -552,7 +555,8 @@ async def main() -> None:
def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes
) -> HID_Device.GetSetStatus:
):
retValue = hid_device.GetSetStatus()
print(
"SET_REPORT report_id: "
+ str(report_id)
@@ -564,33 +568,33 @@ async def main() -> None:
+ str(data)
)
if report_type == Message.ReportType.FEATURE_REPORT:
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData):
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData):
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3:
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
status = HID_Device.GetSetReturn.SUCCESS
retValue.status = hid_device.GetSetReturn.SUCCESS
else:
status = HID_Device.GetSetReturn.SUCCESS
retValue.status = hid_device.GetSetReturn.SUCCESS
return HID_Device.GetSetStatus(status=status)
return retValue
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
return HID_Device.GetSetStatus(
data=bytes([protocol_mode]),
status=hid_device.GetSetReturn.SUCCESS,
)
def on_get_protocol_cb():
retValue = hid_device.GetSetStatus()
retValue.data = protocol_mode.to_bytes()
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
return HID_Device.GetSetStatus(
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')

View File

@@ -35,13 +35,15 @@ from bumble.hci import (
CodingFormat,
OwnAddressType,
)
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import (
CodecSpecificCapabilities,
ContextType,
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
UnicastServerAdvertisingData,
)
from bumble.profiles.mcp import (
@@ -50,7 +52,7 @@ from bumble.profiles.mcp import (
MediaState,
MediaControlPointOpcode,
)
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link
from typing import Optional

View File

@@ -34,8 +34,8 @@ from bumble.hci import (
CodingFormat,
HCI_IsoDataPacket,
)
from bumble.profiles.ascs import AseStateMachine, AudioStreamControlService
from bumble.profiles.bap import (
AseStateMachine,
UnicastServerAdvertisingData,
CodecSpecificConfiguration,
CodecSpecificCapabilities,
@@ -43,10 +43,13 @@ from bumble.profiles.bap import (
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
)
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link

View File

@@ -30,7 +30,6 @@ from bumble.hci import (
CodingFormat,
OwnAddressType,
)
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import (
UnicastServerAdvertisingData,
CodecSpecificCapabilities,
@@ -38,8 +37,10 @@ from bumble.profiles.bap import (
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
)
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.vcp import VolumeControlService

View File

@@ -142,7 +142,7 @@ class MainActivity : ComponentActivity() {
::runRfcommClient,
::runRfcommServer,
::runL2capClient,
::runL2capServer,
::runL2capServer
)
}
@@ -166,8 +166,6 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> runRfcommServer()
"l2cap-client" -> runL2capClient()
"l2cap-server" -> runL2capServer()
"scan-start" -> runScan(true)
"stop-start" -> runScan(false)
}
}
}
@@ -192,11 +190,6 @@ class MainActivity : ComponentActivity() {
l2capServer?.run()
}
private fun runScan(startScan: Boolean) {
val scan = bluetoothAdapter?.let { Scan(it) }
scan?.run(startScan)
}
@SuppressLint("MissingPermission")
fun becomeDiscoverable() {
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
@@ -213,7 +206,7 @@ fun MainView(
runRfcommClient: () -> Unit,
runRfcommServer: () -> Unit,
runL2capClient: () -> Unit,
runL2capServer: () -> Unit,
runL2capServer: () -> Unit
) {
BTBenchTheme {
val scrollState = rememberScrollState()

View File

@@ -1,38 +0,0 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.le.ScanCallback
import android.bluetooth.le.ScanResult
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.scan")
class Scan(val bluetoothAdapter: BluetoothAdapter) {
@SuppressLint("MissingPermission")
fun run(startScan: Boolean) {
var bluetoothLeScanner = bluetoothAdapter.bluetoothLeScanner
val scanCallback = object : ScanCallback() {
override fun onScanResult(callbackType: Int, result: ScanResult?) {
super.onScanResult(callbackType, result)
val device: BluetoothDevice? = result?.device
val deviceName = device?.name ?: "Unknown"
val deviceAddress = device?.address ?: "Unknown"
Log.info("Device found: $deviceName ($deviceAddress)")
}
override fun onScanFailed(errorCode: Int) {
// Handle scan failure
Log.warning("Scan failed with error code: $errorCode")
}
}
if (startScan) {
bluetoothLeScanner?.startScan(scanCallback)
} else {
bluetoothLeScanner?.stopScan(scanCallback)
}
}
}

View File

@@ -70,7 +70,6 @@ console_scripts =
bumble-usb-probe = bumble.apps.usb_probe:main
bumble-link-relay = bumble.apps.link_relay.link_relay:main
bumble-bench = bumble.apps.bench:main
bumble-player = bumble.apps.player.player:main
bumble-speaker = bumble.apps.speaker.speaker:main
bumble-pandora-server = bumble.apps.pandora_server:main
bumble-rtk-util = bumble.tools.rtk_util:main
@@ -100,7 +99,7 @@ development =
types-protobuf >= 4.21.0
wasmtime == 20.0.0
avatar =
pandora-avatar == 0.0.10
pandora-avatar == 0.0.9
rootcanal == 1.10.0 ; python_version>='3.10'
pandora =
bt-test-interfaces >= 0.0.6

View File

@@ -33,16 +33,20 @@ from bumble.avdtp import (
Protocol,
Listener,
MediaCodecCapabilities,
MediaPacket,
AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_TSEP_SNK,
A2DP_SBC_CODEC_TYPE,
)
from bumble.a2dp import (
AacMediaCodecInformation,
OpusMediaCodecInformation,
SbcMediaCodecInformation,
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
)
from bumble.rtp import MediaPacket
# -----------------------------------------------------------------------------
# Logging
@@ -121,12 +125,12 @@ def source_codec_capabilities():
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_44100,
channel_mode=SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
media_codec_information=SbcMediaCodecInformation.from_discrete_values(
sampling_frequency=44100,
channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE,
block_length=16,
subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
@@ -138,23 +142,20 @@ def sink_codec_capabilities():
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_4
| SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| SbcMediaCodecInformation.AllocationMethod.SNR,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
@@ -272,125 +273,7 @@ async def test_source_sink_1():
# -----------------------------------------------------------------------------
def test_sbc_codec_specific_information():
sbc_info = SbcMediaCodecInformation.from_bytes(bytes.fromhex("3fff0235"))
assert (
sbc_info.sampling_frequency
== SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_48000
)
assert (
sbc_info.channel_mode
== SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO
)
assert (
sbc_info.block_length
== SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16
)
assert (
sbc_info.subbands
== SbcMediaCodecInformation.Subbands.S_4 | SbcMediaCodecInformation.Subbands.S_8
)
assert (
sbc_info.allocation_method
== SbcMediaCodecInformation.AllocationMethod.SNR
| SbcMediaCodecInformation.AllocationMethod.LOUDNESS
)
assert sbc_info.minimum_bitpool_value == 2
assert sbc_info.maximum_bitpool_value == 53
sbc_info2 = SbcMediaCodecInformation(
SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_48000,
SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
SbcMediaCodecInformation.Subbands.S_4 | SbcMediaCodecInformation.Subbands.S_8,
SbcMediaCodecInformation.AllocationMethod.SNR
| SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
2,
53,
)
assert sbc_info == sbc_info2
assert bytes(sbc_info2) == bytes.fromhex("3fff0235")
# -----------------------------------------------------------------------------
def test_aac_codec_specific_information():
aac_info = AacMediaCodecInformation.from_bytes(bytes.fromhex("f0018c83e800"))
assert (
aac_info.object_type
== AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LC
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LTP
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_SCALABLE
)
assert (
aac_info.sampling_frequency
== AacMediaCodecInformation.SamplingFrequency.SF_44100
| AacMediaCodecInformation.SamplingFrequency.SF_48000
)
assert (
aac_info.channels
== AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO
)
assert aac_info.vbr == 1
assert aac_info.bitrate == 256000
aac_info2 = AacMediaCodecInformation(
AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LC
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LTP
| AacMediaCodecInformation.ObjectType.MPEG_4_AAC_SCALABLE,
AacMediaCodecInformation.SamplingFrequency.SF_44100
| AacMediaCodecInformation.SamplingFrequency.SF_48000,
AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO,
1,
256000,
)
assert aac_info == aac_info2
assert bytes(aac_info2) == bytes.fromhex("f0018c83e800")
# -----------------------------------------------------------------------------
def test_opus_codec_specific_information():
opus_info = OpusMediaCodecInformation.from_bytes(bytes([0x92]))
assert opus_info.vendor_id == OpusMediaCodecInformation.VENDOR_ID
assert opus_info.codec_id == OpusMediaCodecInformation.CODEC_ID
assert opus_info.frame_size == OpusMediaCodecInformation.FrameSize.FS_20MS
assert opus_info.channel_mode == OpusMediaCodecInformation.ChannelMode.STEREO
assert (
opus_info.sampling_frequency
== OpusMediaCodecInformation.SamplingFrequency.SF_48000
)
opus_info2 = OpusMediaCodecInformation(
OpusMediaCodecInformation.ChannelMode.STEREO,
OpusMediaCodecInformation.FrameSize.FS_20MS,
OpusMediaCodecInformation.SamplingFrequency.SF_48000,
)
assert opus_info2 == opus_info
assert opus_info2.value == bytes([0x92])
# -----------------------------------------------------------------------------
async def async_main():
test_sbc_codec_specific_information()
test_aac_codec_specific_information()
test_opus_codec_specific_information()
async def run_test_self():
await test_self_connection()
await test_source_sink_1()
@@ -398,4 +281,4 @@ async def async_main():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())
asyncio.run(run_test_self())

View File

@@ -1,494 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import pytest
import pytest_asyncio
from bumble import device
from bumble.att import ATT_Error
from bumble.profiles.aics import (
Mute,
AICSService,
AudioInputState,
AICSServiceProxy,
GainMode,
AudioInputStatus,
AudioInputControlPointOpCode,
ErrorCode,
)
from bumble.profiles.vcp import VolumeControlService, VolumeControlServiceProxy
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
aics_service = AICSService()
vcp_service = VolumeControlService(
volume_setting=32, muted=1, volume_flags=1, included_services=[aics_service]
)
@pytest_asyncio.fixture
async def aics_client():
devices = TwoDevices()
devices[0].add_service(vcp_service)
await devices.setup_connection()
assert devices.connections[0]
assert devices.connections[1]
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
vcp_client = await peer.discover_service_and_create_proxy(VolumeControlServiceProxy)
assert vcp_client
included_services = await peer.discover_included_services(vcp_client.service_proxy)
assert included_services
aics_service_discovered = included_services[0]
await peer.discover_characteristics(service=aics_service_discovered)
aics_client = AICSServiceProxy(aics_service_discovered)
yield aics_client
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(aics_client: AICSServiceProxy):
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL,
change_counter=0,
)
assert await aics_client.gain_settings_properties.read_value() == (1, 0, 255)
assert await aics_client.audio_input_status.read_value() == (
AudioInputStatus.ACTIVE
)
@pytest.mark.asyncio
async def test_wrong_opcode_raise_error(aics_client: AICSServiceProxy):
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
0xFF,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.OPCODE_NOT_SUPPORTED
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_automatic_only(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
# Unchanged
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.AUTOMATIC_ONLY,
change_counter=0,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_automatic(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
# Unchanged
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.AUTOMATIC,
change_counter=0,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_MANUAL(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=gain_settings,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL,
change_counter=change_counter,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_MANUAL_ONLY(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=gain_settings,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL_ONLY,
change_counter=change_counter,
)
@pytest.mark.asyncio
async def test_unmute_when_muted(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.MUTED
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.UNMUTE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.NOT_MUTED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_unmute_when_mute_disabled(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.DISABLED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.UNMUTE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.MUTE_DISABLED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.DISABLED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_mute_when_not_muted(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.NOT_MUTED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.MUTE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.MUTED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_mute_when_mute_disabled(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.DISABLED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.MUTE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.MUTE_DISABLED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.DISABLED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_automatic(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_already_manual(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
)
)
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_manual_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_automatic_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_manual(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_already_automatic(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
)
)
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_manual_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_automatic_only(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_audio_input_description_initial_value(aics_client: AICSServiceProxy):
description = await aics_client.audio_input_description.read_value()
assert description.decode('utf-8') == "Bluetooth"
@pytest.mark.asyncio
async def test_audio_input_description_write_and_read(aics_client: AICSServiceProxy):
new_description = "Line Input".encode('utf-8')
await aics_client.audio_input_description.write_value(new_description)
description = await aics_client.audio_input_description.read_value()
assert description == new_description

View File

@@ -1,163 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
import struct
from unittest import mock
from bumble import device as bumble_device
from bumble.profiles import asha
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
HI_SYNC_ID = b'\x00\x01\x02\x03\x04\x05\x06\x07'
TIMEOUT = 0.1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_only_properties():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
protocol_version=0x01,
capability=0x02,
feature_map=0x03,
render_delay_milliseconds=0x04,
supported_codecs=0x05,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
read_only_properties = (
await asha_client.read_only_properties_characteristic.read_value()
)
(
protocol_version,
capabilities,
hi_sync_id,
feature_map,
render_delay_milliseconds,
_,
supported_codecs,
) = struct.unpack("<BB8sBHHH", read_only_properties)
assert protocol_version == 0x01
assert capabilities == 0x02
assert hi_sync_id == HI_SYNC_ID
assert feature_map == 0x03
assert render_delay_milliseconds == 0x04
assert supported_codecs == 0x05
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_psm():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
psm = (await asha_client.psm_characteristic.read_value())[0]
assert psm == asha_service.psm
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_audio_control_point_start():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
status_notifications = asyncio.Queue()
await asha_client.audio_status_point_characteristic.subscribe(
status_notifications.put_nowait
)
start_cb = mock.MagicMock()
asha_service.on('started', start_cb)
await asha_client.audio_control_point_characteristic.write_value(
bytes(
[asha.OpCode.START, asha.Codec.G_722_16KHZ, asha.AudioType.MEDIA, 0, 1]
)
)
status = (await asyncio.wait_for(status_notifications.get(), TIMEOUT))[0]
assert status == asha.AudioStatus.OK
start_cb.assert_called_once()
assert asha_service.active_codec == asha.Codec.G_722_16KHZ
assert asha_service.volume == 0
assert asha_service.other_state == 1
assert asha_service.audio_type == asha.AudioType.MEDIA
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_audio_control_point_stop():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
status_notifications = asyncio.Queue()
await asha_client.audio_status_point_characteristic.subscribe(
status_notifications.put_nowait
)
stop_cb = mock.MagicMock()
asha_service.on('stopped', stop_cb)
await asha_client.audio_control_point_characteristic.write_value(
bytes([asha.OpCode.STOP])
)
status = (await asyncio.wait_for(status_notifications.get(), TIMEOUT))[0]
assert status == asha.AudioStatus.OK
stop_cb.assert_called_once()
assert asha_service.active_codec is None
assert asha_service.volume is None
assert asha_service.other_state is None
assert asha_service.audio_type is None

View File

@@ -23,12 +23,13 @@ from bumble.avdtp import (
AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY,
AVDTP_SET_CONFIGURATION,
Message,
MediaPacket,
Get_Capabilities_Response,
Set_Configuration_Command,
Set_Configuration_Response,
ServiceCapabilities,
MediaCodecCapabilities,
)
from bumble.rtp import MediaPacket
# -----------------------------------------------------------------------------

View File

@@ -23,9 +23,8 @@ import logging
from bumble import device
from bumble.hci import CodecID, CodingFormat
from bumble.profiles.ascs import (
AudioStreamControlService,
AudioStreamControlServiceProxy,
from bumble.profiles.bap import (
AudioLocation,
AseStateMachine,
ASE_Operation,
ASE_Config_Codec,
@@ -36,9 +35,6 @@ from bumble.profiles.ascs import (
ASE_Receiver_Stop_Ready,
ASE_Release,
ASE_Update_Metadata,
)
from bumble.profiles.bap import (
AudioLocation,
SupportedFrameDuration,
SupportedSamplingFrequency,
SamplingFrequency,
@@ -46,9 +42,9 @@ from bumble.profiles.bap import (
CodecSpecificCapabilities,
CodecSpecificConfiguration,
ContextType,
)
from bumble.profiles.pacs import (
PacRecord,
AudioStreamControlService,
AudioStreamControlServiceProxy,
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)

View File

@@ -1,146 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from bumble import hci
from bumble.profiles import bass
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def basic_operation_check(operation: bass.ControlPointOperation) -> None:
serialized = bytes(operation)
parsed = bass.ControlPointOperation.from_bytes(serialized)
assert bytes(parsed) == serialized
# -----------------------------------------------------------------------------
def test_operations() -> None:
op1 = bass.RemoteScanStoppedOperation()
basic_operation_check(op1)
op2 = bass.RemoteScanStartedOperation()
basic_operation_check(op2)
op3 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(),
)
basic_operation_check(op3)
op4 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(
bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')),
bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')),
),
)
basic_operation_check(op4)
op5 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(),
)
basic_operation_check(op5)
op6 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
),
)
basic_operation_check(op6)
op7 = bass.SetBroadcastCodeOperation(
7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf')
)
basic_operation_check(op7)
op8 = bass.RemoveSourceOperation(7)
basic_operation_check(op8)
# -----------------------------------------------------------------------------
def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None:
serialized = bytes(brs)
parsed = bass.BroadcastReceiveState.from_bytes(serialized)
assert parsed is not None
assert bytes(parsed) == serialized
def test_broadcast_receive_state() -> None:
subgroups = [
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
]
brs1 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.DECRYPTING,
b'',
subgroups,
)
basic_broadcast_receive_state_check(brs1)
brs2 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.BAD_CODE,
bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'),
subgroups,
)
basic_broadcast_receive_state_check(brs2)
# -----------------------------------------------------------------------------
async def run():
test_operations()
test_broadcast_receive_state()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -15,9 +15,8 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import random
import pytest
from bumble.codecs import AacAudioRtpPacket, BitReader, BitWriter
from bumble.codecs import AacAudioRtpPacket, BitReader
# -----------------------------------------------------------------------------
@@ -50,58 +49,19 @@ def test_reader():
assert value == int.from_bytes(data, byteorder='big')
def test_writer():
writer = BitWriter()
assert bytes(writer) == b''
for i in range(100):
for j in range(1, 10):
writer = BitWriter()
chunks = []
for k in range(j):
n_bits = random.randint(1, 32)
random_bits = random.getrandbits(n_bits)
chunks.append((n_bits, random_bits))
writer.write(random_bits, n_bits)
written_data = bytes(writer)
reader = BitReader(written_data)
for n_bits, written_bits in chunks:
read_bits = reader.read(n_bits)
assert read_bits == written_bits
def test_aac_rtp():
# pylint: disable=line-too-long
packet_data = bytes.fromhex(
'47fc0000b090800300202066000198000de120000000000000000000000000000000000000000000001c'
)
packet = AacAudioRtpPacket.from_bytes(packet_data)
packet = AacAudioRtpPacket(packet_data)
adts = packet.to_adts()
assert adts == bytes.fromhex(
'fff1508004fffc2066000198000de120000000000000000000000000000000000000000000001c'
)
payload = bytes(list(range(1, 200)))
rtp = AacAudioRtpPacket.for_simple_aac(44100, 2, payload)
assert rtp.audio_mux_element.payload == payload
assert (
rtp.audio_mux_element.stream_mux_config.audio_specific_config.sampling_frequency
== 44100
)
assert (
rtp.audio_mux_element.stream_mux_config.audio_specific_config.channel_configuration
== 2
)
rtp2 = AacAudioRtpPacket.from_bytes(bytes(rtp))
assert str(rtp2.audio_mux_element.stream_mux_config) == str(
rtp.audio_mux_element.stream_mux_config
)
assert rtp2.audio_mux_element.payload == rtp.audio_mux_element.payload
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_reader()
test_writer()
test_aac_rtp()

View File

@@ -536,16 +536,6 @@ async def test_cis_setup_failure():
await asyncio.wait_for(cis_create_task, _TIMEOUT)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
devices = TwoDevices()
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
await devices[0].power_on()
assert devices[0].static_address != Address.ANY_RANDOM
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))

View File

@@ -47,10 +47,8 @@ from bumble.att import (
ATT_EXCHANGE_MTU_REQUEST,
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
ATT_Error,
ATT_Error_Response,
ATT_Read_By_Group_Type_Request,
ErrorCode,
)
from .test_utils import async_barrier
@@ -1249,32 +1247,6 @@ async def test_get_characteristics_by_uuid():
assert len(s) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_return_error():
[client, server] = LinkedDevices().devices[:2]
on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED))
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=on_write),
)
service = Service('ABCD', [characteristic])
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
async with Peer(connection) as peer:
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0]
with pytest.raises(ATT_Error) as e:
await c.write_value(b'', with_response=True)
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

View File

@@ -1,232 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import pytest
import functools
import pytest_asyncio
import logging
import sys
from bumble import att, device
from bumble.profiles import hap
from .test_utils import TwoDevices
from bumble.keys import PairingKeys
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
foo_preset = hap.PresetRecord(1, "foo preset")
bar_preset = hap.PresetRecord(50, "bar preset")
foobar_preset = hap.PresetRecord(5, "foobar preset")
unavailable_preset = hap.PresetRecord(
78,
"foobar preset",
hap.PresetRecord.Property(
hap.PresetRecord.Property.Writable.CANNOT_BE_WRITTEN,
hap.PresetRecord.Property.IsAvailable.IS_UNAVAILABLE,
),
)
server_features = hap.HearingAidFeatures(
hap.HearingAidType.MONAURAL_HEARING_AID,
hap.PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED,
hap.IndependentPresets.IDENTICAL_PRESET_RECORD,
hap.DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE,
hap.WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED,
)
TIMEOUT = 0.1
async def assert_queue_is_empty(queue: asyncio.Queue):
assert queue.empty()
# Check that nothing is being added during TIMEOUT secondes
if sys.version_info >= (3, 11):
with pytest.raises(TimeoutError):
await asyncio.wait_for(queue.get(), TIMEOUT)
else:
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get(), TIMEOUT)
# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def hap_client():
devices = TwoDevices()
devices[0].add_service(
hap.HearingAccessService(
devices[0],
server_features,
[foo_preset, bar_preset, foobar_preset, unavailable_preset],
)
)
await devices.setup_connection()
# TODO negotiate MTU > 49 to not truncate preset names
# Mock encryption.
devices.connections[0].encryption = 1 # type: ignore
devices.connections[1].encryption = 1 # type: ignore
devices[0].on_pairing(
devices.connections[0], devices.connections[0].peer_address, PairingKeys(), True
)
peer = device.Peer(devices.connections[1]) # type: ignore
hap_client = await peer.discover_service_and_create_proxy(
hap.HearingAccessServiceProxy
)
assert hap_client
await hap_client.setup_subscription()
yield hap_client
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(hap_client: hap.HearingAccessServiceProxy):
assert (
hap.HearingAidFeatures_from_bytes(await hap_client.server_features.read_value())
== server_features
)
assert (await hap_client.active_preset_index.read_value()) == (foo_preset.index)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_all_presets(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 1, 0xFF])
)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(foo_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(foobar_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(bar_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 1]
) + bytes(unavailable_preset)
await assert_queue_is_empty(hap_client.preset_control_point_indications)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_partial_presets(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 3, 2])
)
assert (await hap_client.preset_control_point_indications.get())[2:] == bytes(
foobar_preset
)
assert (await hap_client.preset_control_point_indications.get())[2:] == bytes(
bar_preset
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_valid(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes(
[hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET, bar_preset.index]
)
)
assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_invalid(hap_client: hap.HearingAccessServiceProxy):
with pytest.raises(att.ATT_Error) as e:
await hap_client.hearing_aid_preset_control_point.write_value(
bytes(
[
hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET,
unavailable_preset.index,
]
),
with_response=True,
)
assert e.value.error_code == hap.ErrorCode.PRESET_OPERATION_NOT_POSSIBLE
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])
)
assert (
await hap_client.active_preset_index_notification.get()
) == foobar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (foobar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset_will_loop_to_first(
hap_client: hap.HearingAccessServiceProxy,
):
async def go_next(new_preset: hap.PresetRecord):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])
)
assert (
await hap_client.active_preset_index_notification.get()
) == new_preset.index
assert (await hap_client.active_preset_index.read_value()) == (new_preset.index)
await go_next(foobar_preset)
await go_next(bar_preset)
await go_next(foo_preset)
# Note that there is a invalid preset in the preset record of the server
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_previous_preset_will_loop_to_last(
hap_client: hap.HearingAccessServiceProxy,
):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_PREVIOUS_PRESET])
)
assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)

View File

@@ -60,8 +60,6 @@ from bumble.hci import (
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
HCI_PIN_Code_Request_Reply_Command,
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Supported_Features_Command,
HCI_Read_Local_Version_Information_Command,
@@ -478,51 +476,6 @@ def test_HCI_LE_Setup_ISO_Data_Path_Command():
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
returned_parameters = (
HCI_Read_Local_Supported_Codecs_Command.parse_return_parameters(
bytes([HCI_SUCCESS, 3, CodecID.A_LOG, CodecID.CVSD, CodecID.LINEAR_PCM, 0])
)
)
assert returned_parameters.standard_codec_ids == [
CodecID.A_LOG,
CodecID.CVSD,
CodecID.LINEAR_PCM,
]
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = (
HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
HCI_SUCCESS,
3,
CodecID.A_LOG,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
CodecID.CVSD,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
CodecID.LINEAR_PCM,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
0,
]
)
)
)
assert returned_parameters.standard_codec_ids == [
CodecID.A_LOG,
CodecID.CVSD,
CodecID.LINEAR_PCM,
]
assert returned_parameters.standard_codec_transports == [
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
]
# -----------------------------------------------------------------------------
def test_address():
a = Address('C4:F2:17:1A:1D:BB')

View File

@@ -569,37 +569,6 @@ async def test_sco_setup():
await asyncio.gather(*sco_disconnection_futures)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_hf_batched_response(
hfp_connections: Tuple[hfp.HfProtocol, hfp.AgProtocol]
):
hf, ag = hfp_connections
ag.dlc.write(b'\r\n+BIND: (1,2)\r\n\r\nOK\r\n')
await hf.execute_command("AT+BIND=?", response_type=hfp.AtResponseType.SINGLE)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ag_batched_commands(
hfp_connections: Tuple[hfp.HfProtocol, hfp.AgProtocol]
):
hf, ag = hfp_connections
answer_future = asyncio.get_running_loop().create_future()
ag.on('answer', lambda: answer_future.set_result(None))
hang_up_future = asyncio.get_running_loop().create_future()
ag.on('hang_up', lambda: hang_up_future.set_result(None))
hf.dlc.write(b'ATA\rAT+CHUP\r')
await answer_future
await hang_up_future
# -----------------------------------------------------------------------------
async def run():
await test_slc()

View File

@@ -27,6 +27,7 @@ def test_import():
core,
crypto,
device,
gap,
hci,
hfp,
host,
@@ -40,22 +41,6 @@ def test_import():
utils,
)
from bumble.profiles import (
ascs,
bap,
bass,
battery_service,
cap,
csip,
device_information_service,
gap,
heart_rate_service,
le_audio,
pacs,
pbp,
vcp,
)
assert att
assert bridge
assert company_ids
@@ -63,6 +48,7 @@ def test_import():
assert core
assert crypto
assert device
assert gap
assert hci
assert hfp
assert host
@@ -75,20 +61,6 @@ def test_import():
assert transport
assert utils
assert ascs
assert bap
assert bass
assert battery_service
assert cap
assert csip
assert device_information_service
assert gap
assert heart_rate_service
assert le_audio
assert pacs
assert pbp
assert vcp
# -----------------------------------------------------------------------------
def test_app_imports():

View File

@@ -17,17 +17,13 @@
# -----------------------------------------------------------------------------
import pytest
from unittest import mock
from bumble import smp
from bumble import pairing
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address
from bumble.core import AdvertisingData
from bumble.device import Device
from typing import Optional
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -255,57 +251,6 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'identity_address_type, public_address, random_address, expected_identity_address',
[
(
None,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
None,
Address.ANY,
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.PUBLIC,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.RANDOM,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
identity_address_type: Optional[pairing.PairingConfig.AddressType],
public_address: Address,
random_address: Address,
expected_identity_address: Address,
):
device = Device()
device.public_address = public_address
device.static_address = random_address
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
with mock.patch.object(session, 'send_command') as mock_method:
session.send_identity_address_command()
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ecc()

View File

@@ -39,9 +39,6 @@ async def vcp_client():
await devices.setup_connection()
assert devices.connections[0]
assert devices.connections[1]
# Mock encryption.
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1

View File

@@ -28,18 +28,26 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities,
MediaPacket,
Protocol,
)
from bumble.a2dp import (
make_audio_sink_service_sdp_records,
MPEG_2_AAC_LC_OBJECT_TYPE,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_SNR_ALLOCATION_METHOD,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation,
AacMediaCodecInformation,
)
from bumble.utils import AsyncRunner
from bumble.codecs import AacAudioRtpPacket
from bumble.hci import HCI_Reset_Command
from bumble.rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -64,7 +72,7 @@ class AudioExtractor:
# -----------------------------------------------------------------------------
class AacAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
return AacAudioRtpPacket.from_bytes(packet.payload).to_adts()
return AacAudioRtpPacket(packet.payload).to_adts()
# -----------------------------------------------------------------------------
@@ -122,12 +130,10 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=AacMediaCodecInformation.SamplingFrequency.SF_48000
| AacMediaCodecInformation.SamplingFrequency.SF_44100,
channels=AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO,
media_codec_information=AacMediaCodecInformation.from_lists(
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
sampling_frequencies=[48000, 44100],
channels=[1, 2],
vbr=1,
bitrate=256000,
),
@@ -137,23 +143,20 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_4
| SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| SbcMediaCodecInformation.AllocationMethod.SNR,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
@@ -279,6 +282,9 @@ class Speaker:
mitm=False
)
# Start the controller
await self.device.power_on()
# Listen for Bluetooth connections
self.device.on('connection', self.on_bluetooth_connection)
@@ -289,9 +295,6 @@ class Speaker:
self.avdtp_listener = Listener.for_device(self.device)
self.avdtp_listener.on('connection', self.on_avdtp_connection)
# Start the controller
await self.device.power_on()
print(f'Speaker ready to play, codec={self.codec}')
if connect_address: