Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
5e4055bb6b fix attribute discovery and display 2024-08-11 09:33:53 -07:00
57 changed files with 1650 additions and 6397 deletions

View File

@@ -17,11 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import logging
import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
from typing import cast, Dict, Optional, Tuple
import click
import pyee
@@ -33,7 +32,6 @@ import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
@@ -48,16 +46,14 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
# -----------------------------------------------------------------------------
# Scan For Broadcasts
# Discover Broadcasts
# -----------------------------------------------------------------------------
class BroadcastScanner(pyee.EventEmitter):
class BroadcastDiscoverer:
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str
@@ -83,6 +79,22 @@ class BroadcastScanner(pyee.EventEmitter):
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
@@ -127,8 +139,6 @@ class BroadcastScanner(pyee.EventEmitter):
data,
)
self.emit('update')
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
@@ -217,12 +227,13 @@ class BroadcastScanner(pyee.EventEmitter):
)
def on_sync_establishment(self) -> None:
self.emit('sync_establishment')
self.establishment_timeout_task.cancel()
self.emit('change')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('sync_loss')
self.emit('change')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
@@ -257,21 +268,37 @@ class BroadcastScanner(pyee.EventEmitter):
filter_duplicates: bool,
sync_timeout: float,
):
super().__init__()
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement)
async def start(self) -> None:
async def run(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
async def stop(self) -> None:
await self.device.stop_scanning()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if (
@@ -284,6 +311,7 @@ class BroadcastScanner(pyee.EventEmitter):
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
self.refresh()
return
bumble.utils.AsyncRunner.spawn(
@@ -303,318 +331,41 @@ class BroadcastScanner(pyee.EventEmitter):
name,
periodic_advertising_sync,
)
broadcast.on('change', self.refresh)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.emit('new_broadcast', broadcast)
self.status_message = color(
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner:
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
self.scanner.on('update', self.refresh)
self.status_message = ''
async def start(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.scanner.start()
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
)
broadcast.on('change', self.refresh)
broadcast.on('update', self.refresh)
self.refresh()
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
f'-Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.scanner.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
@contextlib.asynccontextmanager
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
async def run_discover_broadcasts(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device_config = bumble.device.DeviceConfiguration(
name=AURACAST_DEFAULT_DEVICE_NAME,
address=AURACAST_DEFAULT_DEVICE_ADDRESS,
keystore='JsonKeyStore',
)
device = bumble.device.Device.from_config_with_hci(
device_config,
device = bumble.device.Device.with_hci(
AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS,
hci_source,
hci_sink,
)
await device.power_on()
yield device
async def find_broadcast_by_name(
device: bumble.device.Device, name: Optional[str]
) -> BroadcastScanner.Broadcast:
result = asyncio.get_running_loop().create_future()
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
if broadcast.basic_audio_announcement and not result.done():
print(color('Broadcast basic audio announcement received', 'green'))
result.set_result(broadcast)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if name is None or broadcast.name == name:
print(color('Broadcast found:', 'green'), broadcast.name)
broadcast.on('change', lambda: on_broadcast_change(broadcast))
return
print(color(f'Skipping broadcast {broadcast.name}'))
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
broadcast = await result
await scanner.stop()
return broadcast
async def run_scan(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
await scanner.start()
await asyncio.get_running_loop().create_future()
async def run_assist(
broadcast_name: Optional[str],
source_id: Optional[int],
command: str,
transport: str,
address: str,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
# Connect to the server
print(f'=== Connecting to {address}...')
connection = await device.connect(address)
peer = bumble.device.Peer(connection)
print(f'=== Connected to {peer}')
print("+++ Encrypting connection...")
await peer.connection.encrypt()
print("+++ Connection encrypted")
# Request a larger MTU
mtu = AURACAST_DEFAULT_ATT_MTU
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
await peer.request_mtu(mtu)
# Get the BASS service
bass = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy
)
# Check that the service was found
if not bass:
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return
# Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
try:
await broadcast_receive_state.subscribe(
lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
)
)
except bumble.core.ProtocolError as error:
print(
color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
'red',
),
error,
)
value = await broadcast_receive_state.read_value()
print(
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
)
if command == 'monitor-state':
await peer.sustain()
return
if command == 'add-source':
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source(
broadcast.sync.advertiser_address,
broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
# Initiate a PA Sync Transfer
await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning.
await bass.remote_scan_stopped()
await peer.sustain()
return
if command == 'modify-source':
if source_id is None:
print(color('!!! modify-source requires --source-id'))
return
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Modify the source
print(
color('Modifying source:', 'blue'),
source_id,
)
await bass.modify_source(
source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
await peer.sustain()
return
if command == 'remove-source':
if source_id is None:
print(color('!!! remove-source requires --source-id'))
return
# Remove the source
print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id)
await peer.sustain()
return
print(color(f'!!! invalid command {command}'))
async def run_pair(transport: str, address: str) -> None:
async with create_device(transport) as device:
# Connect to the server
print(f'=== Connecting to {address}...')
async with device.connect_as_gatt(address) as peer:
print(f'=== Connected to {peer}')
print("+++ Initiating pairing...")
await peer.connection.pair()
print("+++ Paired")
def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
except bumble.core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError
):
message = bumble.profiles.bass.ApplicationError(error.error_code).name
else:
message = str(error)
print(
color('!!! An error occurred while executing the command:', 'red'), message
)
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run()
await hci_source.terminated
# -----------------------------------------------------------------------------
@@ -628,7 +379,7 @@ def auracast(
ctx.ensure_object(dict)
@auracast.command('scan')
@auracast.command('discover-broadcasts')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@@ -636,50 +387,14 @@ def auracast(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
default=5.0,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def scan(ctx, filter_duplicates, sync_timeout, transport):
"""Scan for public broadcasts"""
run_async(run_scan(filter_duplicates, sync_timeout, transport))
@auracast.command('assist')
@click.option(
'--broadcast-name',
metavar='BROADCAST_NAME',
help='Broadcast Name to tune to',
)
@click.option(
'--source-id',
metavar='SOURCE_ID',
type=int,
help='Source ID (for remove-source command)',
)
@click.option(
'--command',
type=click.Choice(
['monitor-state', 'add-source', 'modify-source', 'remove-source']
),
required=True,
)
@click.argument('transport')
@click.argument('address')
@click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address):
"""Scan for broadcasts on behalf of a audio server"""
run_async(run_assist(broadcast_name, source_id, command, transport, address))
@auracast.command('pair')
@click.argument('transport')
@click.argument('address')
@click.pass_context
def pair(ctx, transport, address):
"""Pair with an audio server"""
run_async(run_pair(transport, address))
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
def main():

View File

@@ -168,6 +168,7 @@ class ConsoleApp:
'remote-services': None,
'local-values': None,
'remote-values': None,
'remote-attributes': None,
},
'filter': {
'address': None,
@@ -216,6 +217,7 @@ class ConsoleApp:
)
self.local_values_text = FormattedTextControl()
self.remote_values_text = FormattedTextControl()
self.remote_attributes_text = FormattedTextControl()
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = []
@@ -242,6 +244,12 @@ class ConsoleApp:
Frame(Window(self.remote_values_text), title='Remote Values'),
filter=Condition(lambda: self.top_tab == 'remote-values'),
),
ConditionalContainer(
Frame(
Window(self.remote_attributes_text), title='Remote Attributes'
),
filter=Condition(lambda: self.top_tab == 'remote-attributes'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log'),
@@ -504,6 +512,8 @@ class ConsoleApp:
await self.connected_peer.discover_all()
self.append_to_output('Service Discovery done!')
self.show_remote_services(self.connected_peer.services)
async def discover_attributes(self):
if not self.connected_peer:
self.show_error('not connected')
@@ -514,7 +524,7 @@ class ConsoleApp:
attributes = await self.connected_peer.discover_attributes()
self.append_to_output(f'discovered {len(attributes)} attributes...')
self.show_attributes(attributes)
await self.show_remote_attributes(attributes)
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
if not self.connected_peer:
@@ -659,7 +669,6 @@ class ConsoleApp:
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except bumble.core.TimeoutError:
self.show_error('connection timed out')
@@ -730,19 +739,20 @@ class ConsoleApp:
'remote-services',
'local-values',
'remote-values',
'remote-attributes',
}:
self.top_tab = params[0]
self.ui.invalidate()
while self.top_tab == 'local-values':
await self.do_show_local_values()
await self.show_local_values()
await asyncio.sleep(1)
while self.top_tab == 'remote-values':
await self.do_show_remote_values()
await self.show_remote_values()
await asyncio.sleep(1)
async def do_show_local_values(self):
async def show_local_values(self):
prettytable = PrettyTable()
field_names = ["Service", "Characteristic", "Descriptor"]
@@ -797,7 +807,7 @@ class ConsoleApp:
self.local_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def do_show_remote_values(self):
async def show_remote_values(self):
prettytable = PrettyTable(
field_names=[
"Connection",
@@ -831,6 +841,23 @@ class ConsoleApp:
self.remote_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def show_remote_attributes(self, attributes):
lines = []
for attribute in attributes:
lines.append(('ansimagenta', str(attribute) + "\n"))
try:
value = await attribute.read_value()
lines.append(('ansicyan', value.hex() + "\n"))
except bumble.core.ProtocolError as error:
lines.append(("ansired", f"!!! Protocol Error ({error})\n"))
except bumble.core.TimeoutError:
lines.append(("ansired", "!!! Timeout\n"))
except Exception as error:
lines.append(("ansired", f"!!! Error ({error})\n"))
self.remote_attributes_text.text = lines
self.ui.invalidate()
async def do_get_phy(self, _):
if not self.connected_peer:
self.show_error('not connected')

View File

@@ -1,230 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from typing import Callable, Iterable, Optional
import click
from bumble.core import ProtocolError
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def try_show(function: Callable, *args, **kwargs) -> None:
try:
await function(*args, **kwargs)
except ProtocolError as error:
print(color('ERROR:', 'red'), error)
# -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
# -----------------------------------------------------------------------------
async def show_gap_information(
gap_service: GenericAccessServiceProxy,
):
print(color('### Generic Access Profile', 'yellow'))
if gap_service.device_name:
print(
color(' Device Name:', 'green'),
await gap_service.device_name.read_value(),
)
if gap_service.appearance:
print(
color(' Appearance: ', 'green'),
await gap_service.appearance.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_information(
device_information_service: DeviceInformationServiceProxy,
):
print(color('### Device Information', 'yellow'))
if device_information_service.manufacturer_name:
print(
color(' Manufacturer Name:', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number:
print(
color(' Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number:
print(
color(' Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.firmware_revision:
print(
color(' Firmware Revision:', 'green'),
await device_information_service.firmware_revision.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_battery_level(
battery_service: BatteryServiceProxy,
):
print(color('### Battery Information', 'yellow'))
if battery_service.battery_level:
print(
color(' Battery Level:', 'green'),
await battery_service.battery_level.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_tmas(
tmas: TelephonyAndMediaAudioServiceProxy,
):
print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role:
print(
color(' Role:', 'green'),
await tmas.role.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
await peer.discover_services()
for service in peer.services:
await service.discover_characteristics()
print(color('=== Services ===', 'yellow'))
show_services(peer.services)
print()
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
await try_show(show_gap_information, gap_service)
if device_information_service := peer.create_service_proxy(
DeviceInformationServiceProxy
):
await try_show(show_device_information, device_information_service)
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
await try_show(show_battery_level, battery_service)
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas)
if done is not None:
done.set_result(None)
except asyncio.CancelledError:
print(color('!!! Operation canceled', 'red'))
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await show_device_info(Peer(connection), None)
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on(
'connection',
lambda connection: asyncio.create_task(
show_device_info(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
await done
# -----------------------------------------------------------------------------
@click.command()
@click.option('--device-config', help='Device configuration', type=click.Path())
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
@click.argument('transport')
@click.argument('address-or-name', required=False)
def main(device_config, encrypt, transport, address_or_name):
"""
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -75,15 +75,11 @@ async def async_main(device_config, encrypt, transport, address_or_name):
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await dump_gatt_db(Peer(connection), None)
else:

View File

@@ -33,6 +33,7 @@ import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
@@ -42,7 +43,7 @@ from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
@@ -56,8 +57,8 @@ logger = logging.getLogger(__name__)
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
@@ -78,8 +79,8 @@ def _sink_pac_record() -> pacs.PacRecord:
)
def _source_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
@@ -446,7 +447,7 @@ class Speaker:
)
self.device.add_service(
pacs.PublishedAudioCapabilitiesService(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
@@ -460,10 +461,10 @@ class Speaker:
)
)
ascs_service = ascs.AudioStreamControlService(
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs_service)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
@@ -478,13 +479,13 @@ class Speaker:
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
@@ -494,12 +495,12 @@ class Speaker:
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
if ase.state == ascs.AseStateMachine.State.STREAMING:
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == ascs.AudioRole.SOURCE:
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
@@ -515,10 +516,10 @@ class Speaker:
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == ascs.AudioRole.SOURCE:
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
@@ -531,7 +532,7 @@ class Speaker:
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs_service.ase_state_machines.values():
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()

View File

@@ -46,12 +46,6 @@ from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
@@ -241,10 +235,8 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', lambda keys: on_pairing(connection, keys))
connection.on(
'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
connection.on(
@@ -278,24 +270,19 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing(connection, keys):
def on_pairing(address, keys):
print(color('***-----------------------------------', 'cyan'))
print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
await asyncio.sleep(POST_PAIRING_DELAY)
await connection.disconnect()
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
Waiter.instance.terminate()
@@ -306,7 +293,6 @@ async def pair(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -396,18 +382,11 @@ async def pair(
oob_contexts = None
# Set up a pairing config factory
if identity_address == 'public':
identity_address_type = PairingConfig.AddressType.PUBLIC
elif identity_address == 'random':
identity_address_type = PairingConfig.AddressType.RANDOM
else:
identity_address_type = None
device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
identity_address_type=identity_address_type,
delegate=Delegate(mode, connection, io, prompt),
)
@@ -478,10 +457,6 @@ class LogHandler(logging.Handler):
help='Enable CTKD',
show_default=True,
)
@click.option(
'--identity-address',
type=click.Choice(['random', 'public']),
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
@@ -518,7 +493,6 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -544,7 +518,6 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,

View File

@@ -23,7 +23,6 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import functools
import inspect
@@ -42,7 +41,6 @@ from typing import (
from pyee import EventEmitter
from bumble import utils
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
@@ -147,57 +145,43 @@ ATT_RESPONSES = [
ATT_EXECUTE_WRITE_RESPONSE
]
class ErrorCode(utils.OpenIntEnum):
'''
See
ATT_INVALID_HANDLE_ERROR = 0x01
ATT_READ_NOT_PERMITTED_ERROR = 0x02
ATT_WRITE_NOT_PERMITTED_ERROR = 0x03
ATT_INVALID_PDU_ERROR = 0x04
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = 0x05
ATT_REQUEST_NOT_SUPPORTED_ERROR = 0x06
ATT_INVALID_OFFSET_ERROR = 0x07
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = 0x08
ATT_PREPARE_QUEUE_FULL_ERROR = 0x09
ATT_ATTRIBUTE_NOT_FOUND_ERROR = 0x0A
ATT_ATTRIBUTE_NOT_LONG_ERROR = 0x0B
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = 0x0C
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = 0x0D
ATT_UNLIKELY_ERROR_ERROR = 0x0E
ATT_INSUFFICIENT_ENCRYPTION_ERROR = 0x0F
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = 0x10
ATT_INSUFFICIENT_RESOURCES_ERROR = 0x11
* Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
* Core Specification Supplement: Common Profile And Service Error Codes
'''
INVALID_HANDLE = 0x01
READ_NOT_PERMITTED = 0x02
WRITE_NOT_PERMITTED = 0x03
INVALID_PDU = 0x04
INSUFFICIENT_AUTHENTICATION = 0x05
REQUEST_NOT_SUPPORTED = 0x06
INVALID_OFFSET = 0x07
INSUFFICIENT_AUTHORIZATION = 0x08
PREPARE_QUEUE_FULL = 0x09
ATTRIBUTE_NOT_FOUND = 0x0A
ATTRIBUTE_NOT_LONG = 0x0B
INSUFFICIENT_ENCRYPTION_KEY_SIZE = 0x0C
INVALID_ATTRIBUTE_LENGTH = 0x0D
UNLIKELY_ERROR = 0x0E
INSUFFICIENT_ENCRYPTION = 0x0F
UNSUPPORTED_GROUP_TYPE = 0x10
INSUFFICIENT_RESOURCES = 0x11
DATABASE_OUT_OF_SYNC = 0x12
VALUE_NOT_ALLOWED = 0x13
# 0x80 0x9F: Application Error
# 0xE0 0xFF: Common Profile and Service Error Codes
WRITE_REQUEST_REJECTED = 0xFC
CCCD_IMPROPERLY_CONFIGURED = 0xFD
PROCEDURE_ALREADY_IN_PROGRESS = 0xFE
OUT_OF_RANGE = 0xFF
# Backward Compatible Constants
ATT_INVALID_HANDLE_ERROR = ErrorCode.INVALID_HANDLE
ATT_READ_NOT_PERMITTED_ERROR = ErrorCode.READ_NOT_PERMITTED
ATT_WRITE_NOT_PERMITTED_ERROR = ErrorCode.WRITE_NOT_PERMITTED
ATT_INVALID_PDU_ERROR = ErrorCode.INVALID_PDU
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = ErrorCode.INSUFFICIENT_AUTHENTICATION
ATT_REQUEST_NOT_SUPPORTED_ERROR = ErrorCode.REQUEST_NOT_SUPPORTED
ATT_INVALID_OFFSET_ERROR = ErrorCode.INVALID_OFFSET
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = ErrorCode.INSUFFICIENT_AUTHORIZATION
ATT_PREPARE_QUEUE_FULL_ERROR = ErrorCode.PREPARE_QUEUE_FULL
ATT_ATTRIBUTE_NOT_FOUND_ERROR = ErrorCode.ATTRIBUTE_NOT_FOUND
ATT_ATTRIBUTE_NOT_LONG_ERROR = ErrorCode.ATTRIBUTE_NOT_LONG
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION_KEY_SIZE
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = ErrorCode.INVALID_ATTRIBUTE_LENGTH
ATT_UNLIKELY_ERROR_ERROR = ErrorCode.UNLIKELY_ERROR
ATT_INSUFFICIENT_ENCRYPTION_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = ErrorCode.UNSUPPORTED_GROUP_TYPE
ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
ATT_ERROR_NAMES = {
ATT_INVALID_HANDLE_ERROR: 'ATT_INVALID_HANDLE_ERROR',
ATT_READ_NOT_PERMITTED_ERROR: 'ATT_READ_NOT_PERMITTED_ERROR',
ATT_WRITE_NOT_PERMITTED_ERROR: 'ATT_WRITE_NOT_PERMITTED_ERROR',
ATT_INVALID_PDU_ERROR: 'ATT_INVALID_PDU_ERROR',
ATT_INSUFFICIENT_AUTHENTICATION_ERROR: 'ATT_INSUFFICIENT_AUTHENTICATION_ERROR',
ATT_REQUEST_NOT_SUPPORTED_ERROR: 'ATT_REQUEST_NOT_SUPPORTED_ERROR',
ATT_INVALID_OFFSET_ERROR: 'ATT_INVALID_OFFSET_ERROR',
ATT_INSUFFICIENT_AUTHORIZATION_ERROR: 'ATT_INSUFFICIENT_AUTHORIZATION_ERROR',
ATT_PREPARE_QUEUE_FULL_ERROR: 'ATT_PREPARE_QUEUE_FULL_ERROR',
ATT_ATTRIBUTE_NOT_FOUND_ERROR: 'ATT_ATTRIBUTE_NOT_FOUND_ERROR',
ATT_ATTRIBUTE_NOT_LONG_ERROR: 'ATT_ATTRIBUTE_NOT_LONG_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR',
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR: 'ATT_INVALID_ATTRIBUTE_LENGTH_ERROR',
ATT_UNLIKELY_ERROR_ERROR: 'ATT_UNLIKELY_ERROR_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_ERROR',
ATT_UNSUPPORTED_GROUP_TYPE_ERROR: 'ATT_UNSUPPORTED_GROUP_TYPE_ERROR',
ATT_INSUFFICIENT_RESOURCES_ERROR: 'ATT_INSUFFICIENT_RESOURCES_ERROR'
}
ATT_DEFAULT_MTU = 23
@@ -261,9 +245,9 @@ class ATT_PDU:
def pdu_name(op_code):
return name_or_number(ATT_PDU_NAMES, op_code, 2)
@classmethod
def error_name(cls, error_code: int) -> str:
return ErrorCode(error_code).name
@staticmethod
def error_name(error_code):
return name_or_number(ATT_ERROR_NAMES, error_code, 2)
@staticmethod
def subclass(fields):

View File

@@ -580,10 +580,10 @@ class ServiceCapabilities:
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details: Optional[List[str]] = None) -> str:
def to_string(self, details: List[str] = []) -> str:
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ (details or [])
+ details
)
return f'ServiceCapabilities({attributes})'

View File

@@ -148,10 +148,6 @@ class InvalidOperationError(BaseBumbleError, RuntimeError):
"""Invalid Operation Error"""
class NotSupportedError(BaseBumbleError, RuntimeError):
"""Not Supported"""
class OutOfResourcesError(BaseBumbleError, RuntimeError):
"""Out of Resources Error"""

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -151,7 +149,7 @@ QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder:
class G722Decoder(object):
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
@@ -159,7 +157,7 @@ class G722Decoder:
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self) -> None:
def __init__(self):
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
@@ -167,12 +165,12 @@ class G722Decoder:
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
def decode_frame(self, encoded_data) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
def g722_decode(self, result_array, encoded_data) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
@@ -200,16 +198,14 @@ class G722Decoder:
return result_length
def update_decoded_result(
self, xout: int, byte_length: int, byte_array: bytearray
) -> int:
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits: int) -> int:
def lower_sub_band_decoder(self, lower_bits) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
@@ -262,7 +258,7 @@ class G722Decoder:
return rlow
def higher_sub_band_decoder(self, higher_bits: int) -> int:
def higher_sub_band_decoder(self, higher_bits) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
@@ -310,14 +306,14 @@ class G722Decoder:
# -----------------------------------------------------------------------------
class Band:
"""Structure for G722 decode processing."""
class Band(object):
"""Structure for G722 decode proccessing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self) -> None:
def __init__(self):
self._sp = 0
self._sz = 0
self._r = [0] * 3

View File

@@ -51,7 +51,6 @@ from typing_extensions import Self
from pyee import EventEmitter
from bumble import hci
from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from .gatt import Characteristic, Descriptor, Service
@@ -113,7 +112,6 @@ from .hci import (
HCI_LE_Periodic_Advertising_Create_Sync_Command,
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command,
HCI_LE_Periodic_Advertising_Report_Event,
HCI_LE_Periodic_Advertising_Sync_Transfer_Command,
HCI_LE_Periodic_Advertising_Terminate_Sync_Command,
HCI_LE_Enable_Encryption_Command,
HCI_LE_Extended_Advertising_Report_Event,
@@ -170,12 +168,11 @@ from .hci import (
OwnAddressType,
LeFeature,
LeFeatureMask,
LmpFeatureMask,
Phy,
phy_list_to_bits,
)
from .host import Host
from .profiles.gap import GenericAccessService
from .gap import GenericAccessService
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
@@ -190,7 +187,6 @@ from .core import (
InvalidArgumentError,
InvalidOperationError,
InvalidStateError,
NotSupportedError,
OutOfResourcesError,
UnreachableError,
)
@@ -207,13 +203,13 @@ from .keys import (
KeyStore,
PairingKeys,
)
from bumble import pairing
from bumble import gatt_client
from bumble import gatt_server
from bumble import smp
from bumble import sdp
from bumble import l2cap
from bumble import core
from .pairing import PairingConfig
from . import gatt_client
from . import gatt_server
from . import smp
from . import sdp
from . import l2cap
from . import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
@@ -972,25 +968,20 @@ class PeriodicAdvertisingSync(EventEmitter):
response = await self.device.send_command(
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(),
)
if response.return_parameters == HCI_SUCCESS:
if response.status == HCI_SUCCESS:
if self in self.device.periodic_advertising_syncs:
self.device.periodic_advertising_syncs.remove(self)
return
if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST):
self.state = self.State.TERMINATED
if self.sync_handle is not None:
await self.device.send_command(
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
sync_handle=self.sync_handle
)
await self.device.send_command(
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
sync_handle=self.sync_handle
)
)
self.device.periodic_advertising_syncs.remove(self)
async def transfer(self, connection: Connection, service_data: int = 0) -> None:
if self.sync_handle is not None:
await connection.transfer_periodic_sync(self.sync_handle, service_data)
def on_establishment(
self,
status,
@@ -1216,13 +1207,8 @@ class Peer:
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
if proxy := proxy_class.from_client(self.gatt_client):
return cast(_PROXY_CLASS, proxy)
return None
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
@@ -1507,9 +1493,11 @@ class Connection(CompositeEventEmitter):
try:
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
finally:
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
except asyncio.TimeoutError:
pass
self.remove_listener('disconnection', abort.set_result)
self.remove_listener('disconnection_failure', abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)
@@ -1540,11 +1528,6 @@ class Connection(CompositeEventEmitter):
async def get_phy(self):
return await self.device.get_connection_phy(self)
async def transfer_periodic_sync(
self, sync_handle: int, service_data: int = 0
) -> None:
await self.device.transfer_periodic_sync(self, sync_handle, service_data)
# [Classic only]
async def request_remote_name(self):
return await self.device.request_remote_name(self)
@@ -1600,7 +1583,6 @@ class DeviceConfiguration:
classic_ssp_enabled: bool = True
classic_smp_enabled: bool = True
classic_accept_any: bool = True
classic_interlaced_scan_enabled: bool = True
connectable: bool = True
discoverable: bool = True
advertising_data: bytes = bytes(
@@ -1613,8 +1595,6 @@ class DeviceConfiguration:
address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
@@ -1915,7 +1895,6 @@ class Device(CompositeEventEmitter):
self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_smp_enabled = config.classic_smp_enabled
self.classic_interlaced_scan_enabled = config.classic_interlaced_scan_enabled
self.discoverable = config.discoverable
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
@@ -1980,19 +1959,7 @@ class Device(CompositeEventEmitter):
# Setup SMP
self.smp_manager = smp.Manager(
self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
self, pairing_config_factory=lambda connection: PairingConfig()
)
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2216,15 +2183,10 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
),
check_result=True,
)
)
if self.le_enabled:
# Generate a random address if not set.
if self.static_address == Address.ANY_RANDOM:
self.static_address = Address.generate_static_address()
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
@@ -2234,8 +2196,23 @@ class Device(CompositeEventEmitter):
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
else:
self.random_address = self.static_address
# Set the controller address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
)
# Ensure the address bytes can be a static random address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
if self.random_address != Address.ANY_RANDOM:
logger.debug(
@@ -2260,8 +2237,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
),
check_result=True,
)
)
if self.cis_enabled:
@@ -2269,8 +2245,7 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1,
),
check_result=True,
)
)
if self.classic_enabled:
@@ -2293,21 +2268,6 @@ class Device(CompositeEventEmitter):
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
if self.classic_interlaced_scan_enabled:
if self.host.supports_lmp_features(LmpFeatureMask.INTERLACED_PAGE_SCAN):
await self.send_command(
hci.HCI_Write_Page_Scan_Type_Command(page_scan_type=1),
check_result=True,
)
if self.host.supports_lmp_features(
LmpFeatureMask.INTERLACED_INQUIRY_SCAN
):
await self.send_command(
hci.HCI_Write_Inquiry_Scan_Type_Command(scan_type=1),
check_result=True,
)
# Done
self.powered_on = True
@@ -2405,10 +2365,6 @@ class Device(CompositeEventEmitter):
def supports_le_extended_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING)
@property
def supports_le_periodic_advertising(self):
return self.supports_le_features(LeFeatureMask.LE_PERIODIC_ADVERTISING)
async def start_advertising(
self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
@@ -2811,10 +2767,6 @@ class Device(CompositeEventEmitter):
sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT,
filter_duplicates: bool = False,
) -> PeriodicAdvertisingSync:
# Check that the controller supports the feature.
if not self.supports_le_periodic_advertising:
raise NotSupportedError()
# Check that there isn't already an equivalent entry
if any(
sync.advertiser_address == advertiser_address and sync.sid == sid
@@ -3012,47 +2964,18 @@ class Device(CompositeEventEmitter):
] = None,
own_address_type: int = OwnAddressType.RANDOM,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
always_resolve: bool = False,
) -> Connection:
'''
Request a connection to a peer.
When the transport is BLE, this method cannot be called if there is already a
When transport is BLE, this method cannot be called if there is already a
pending connection.
Args:
peer_address:
Address or name of the device to connect to.
If a string is passed:
If the string is an address followed by a `@` suffix, the `always_resolve`
argument is implicitly set to True, so the connection is made to the
address after resolution.
If the string is any other address, the connection is made to that
address (with or without address resolution, depending on the
`always_resolve` argument).
For any other string, a scan for devices using that string as their name
is initiated, and a connection to the first matching device's address
is made. In that case, `always_resolve` is ignored.
connection_parameters_preferences: (BLE only, ignored for BR/EDR)
* None: use the 1M PHY with default parameters
* map: each entry has a PHY as key and a ConnectionParametersPreferences
object as value
connection_parameters_preferences:
(BLE only, ignored for BR/EDR)
* None: use the 1M PHY with default parameters
* map: each entry has a PHY as key and a ConnectionParametersPreferences
object as value
own_address_type:
(BLE only, ignored for BR/EDR)
OwnAddressType.RANDOM to use this device's random address, or
OwnAddressType.PUBLIC to use this device's public address.
timeout:
Maximum time to wait for a connection to be established, in seconds.
Pass None for an unlimited time.
always_resolve:
(BLE only, ignored for BR/EDR)
If True, always initiate a scan, resolving addresses, and connect to the
address that resolves to `peer_address`.
own_address_type: (BLE only)
'''
# Check parameters
@@ -3071,19 +2994,11 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str):
try:
if transport == BT_LE_TRANSPORT and peer_address.endswith('@'):
peer_address = Address.from_string_for_transport(
peer_address[:-1], transport
)
always_resolve = True
logger.debug('forcing address resolution')
else:
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except (InvalidArgumentError, ValueError):
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(
peer_address, transport
@@ -3098,12 +3013,6 @@ class Device(CompositeEventEmitter):
assert isinstance(peer_address, Address)
if transport == BT_LE_TRANSPORT and always_resolve:
logger.debug('resolving address')
peer_address = await self.find_peer_by_identity_address(
peer_address
) # TODO: timeout
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
@@ -3605,26 +3514,15 @@ class Device(CompositeEventEmitter):
check_result=True,
)
async def transfer_periodic_sync(
self, connection: Connection, sync_handle: int, service_data: int = 0
) -> None:
return await self.send_command(
HCI_LE_Periodic_Advertising_Sync_Transfer_Command(
connection_handle=connection.handle,
service_data=service_data,
sync_handle=sync_handle,
),
check_result=True,
)
async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT):
"""
Scan for a peer with a given name and return its address.
Scan for a peer with a give name and return its address and transport
"""
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()
# Scan/inquire with event handlers to handle scan/inquiry results
def on_peer_found(address, ad_data):
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
if local_name is None:
@@ -3633,13 +3531,13 @@ class Device(CompositeEventEmitter):
if local_name.decode('utf-8') == name:
peer_address.set_result(address)
listener = None
handler = None
was_scanning = self.scanning
was_discovering = self.discovering
try:
if transport == BT_LE_TRANSPORT:
event_name = 'advertisement'
listener = self.on(
handler = self.on(
event_name,
lambda advertisement: on_peer_found(
advertisement.address, advertisement.data
@@ -3651,7 +3549,7 @@ class Device(CompositeEventEmitter):
elif transport == BT_BR_EDR_TRANSPORT:
event_name = 'inquiry_result'
listener = self.on(
handler = self.on(
event_name,
lambda address, class_of_device, eir_data, rssi: on_peer_found(
address, eir_data
@@ -3665,67 +3563,21 @@ class Device(CompositeEventEmitter):
return await self.abort_on('flush', peer_address)
finally:
if listener is not None:
self.remove_listener(event_name, listener)
if handler is not None:
self.remove_listener(event_name, handler)
if transport == BT_LE_TRANSPORT and not was_scanning:
await self.stop_scanning()
elif transport == BT_BR_EDR_TRANSPORT and not was_discovering:
await self.stop_discovery()
async def find_peer_by_identity_address(self, identity_address: Address) -> Address:
"""
Scan for a peer with a resolvable address that can be resolved to a given
identity address.
"""
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()
def on_peer_found(address, _):
if address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching public address found for {address}')
peer_address.set_result(address)
return
if address.is_resolvable:
resolved_address = self.address_resolver.resolve(address)
if resolved_address == identity_address:
if not peer_address.done():
logger.debug(f'*** Matching identity found for {address}')
peer_address.set_result(address)
return
was_scanning = self.scanning
event_name = 'advertisement'
listener = None
try:
listener = self.on(
event_name,
lambda advertisement: on_peer_found(
advertisement.address, advertisement.data
),
)
if not self.scanning:
await self.start_scanning(filter_duplicates=True)
return await self.abort_on('flush', peer_address)
finally:
if listener is not None:
self.remove_listener(event_name, listener)
if not was_scanning:
await self.stop_scanning()
@property
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter
def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
self, pairing_config_factory: Callable[[Connection], PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory
@@ -3845,7 +3697,6 @@ class Device(CompositeEventEmitter):
if self.keystore is None:
raise InvalidOperationError('no key store')
logger.debug(f'Looking up key for {connection.peer_address}')
keys = await self.keystore.get(str(connection.peer_address))
if keys is None:
raise InvalidOperationError('keys not found in key store')
@@ -4281,12 +4132,6 @@ class Device(CompositeEventEmitter):
else self.public_address
)
if advertising_set.advertising_parameters.own_address_type in (
OwnAddressType.RANDOM,
OwnAddressType.PUBLIC,
):
connection.self_resolvable_address = None
# Setup auto-restart of the advertising set if needed.
if advertising_set.auto_restart:
connection.once(
@@ -4330,15 +4175,6 @@ class Device(CompositeEventEmitter):
role: int,
connection_parameters: ConnectionParameters,
) -> None:
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if (
peer_resolvable_address == Address.ANY_RANDOM
or not peer_address.is_resolved
):
peer_resolvable_address = None
logger.debug(
f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
@@ -4370,7 +4206,6 @@ class Device(CompositeEventEmitter):
peer_address = resolved_address
self_address = None
own_address_type: Optional[int] = None
if role == HCI_CENTRAL_ROLE:
own_address_type = self.connect_own_address_type
assert own_address_type is not None
@@ -4399,10 +4234,11 @@ class Device(CompositeEventEmitter):
else self.random_address
)
# Some controllers may return local resolvable address even not using address
# generation offloading. Ignore the value to prevent SMP failure.
if own_address_type in (OwnAddressType.RANDOM, OwnAddressType.PUBLIC):
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if peer_resolvable_address == Address.ANY_RANDOM:
peer_resolvable_address = None
# Create a connection.
connection = Connection(

View File

@@ -301,8 +301,6 @@ class Driver(common.Driver):
fw_name: str = ""
config_name: str = ""
POST_RESET_DELAY: float = 0.2
DRIVER_INFOS = [
# 8723A
DriverInfo(
@@ -497,24 +495,12 @@ class Driver(common.Driver):
@classmethod
async def driver_info_for_host(cls, host):
try:
await host.send_command(
HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True
command = HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
logger.error("failed to probe local version information")
return None
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
local_version = response.return_parameters
logger.debug(

View File

@@ -39,7 +39,7 @@ from typing import (
)
from bumble.colors import color
from bumble.core import BaseBumbleError, UUID
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
@@ -238,22 +238,22 @@ GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Call Friendly Name')
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
@@ -275,11 +275,6 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# Hearing Access Service
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC = UUID.from_16_bits(0x2BDC, 'Active Preset Index')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -325,11 +320,6 @@ def show_services(services: Iterable[Service]) -> None:
print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
class InvalidServiceError(BaseBumbleError):
"""The service is not compliant with the spec/profile"""
# -----------------------------------------------------------------------------
class Service(Attribute):
'''
@@ -345,7 +335,7 @@ class Service(Attribute):
uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: Iterable[Service] = (),
included_services: List[Service] = [],
) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
@@ -361,7 +351,7 @@ class Service(Attribute):
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = list(included_services)
self.included_services = included_services[:]
self.characteristics = characteristics[:]
self.primary = primary
@@ -395,7 +385,7 @@ class TemplateService(Service):
self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: Iterable[Service] = (),
included_services: List[Service] = [],
) -> None:
super().__init__(self.UUID, characteristics, primary, included_services)

View File

@@ -68,7 +68,7 @@ from .att import (
ATT_Error,
)
from . import core
from .core import UUID, InvalidStateError
from .core import UUID, InvalidStateError, ProtocolError
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -253,7 +253,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: Type[TemplateService]
@classmethod
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -283,8 +283,6 @@ class Client:
self.services = []
self.cached_values = {}
connection.on('disconnection', self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
@@ -345,7 +343,12 @@ class Client:
self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
@@ -402,7 +405,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -931,7 +934,12 @@ class Client:
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
@@ -953,7 +961,12 @@ class Client:
ATT_INVALID_OFFSET_ERROR,
):
break
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
part = response.part_attribute_value
attribute_value += part
@@ -964,6 +977,7 @@ class Client:
offset += len(part)
self.cache_value(attribute_handle, attribute_value)
# Return the value as bytes
return attribute_value
@@ -1046,7 +1060,12 @@ class Client:
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ATT_Error(error_code=response.error_code, message=response)
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
else:
await self.send_command(
ATT_Write_Command(
@@ -1054,10 +1073,6 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'

View File

@@ -915,7 +915,7 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
# Check that the attribute exists
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
@@ -942,19 +942,11 @@ class Server(EventEmitter):
)
return
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
# Done
response = ATT_Write_Response()
self.send_response(connection, response)
# Accept the value
await attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):

View File

@@ -267,19 +267,6 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
HCI_LE_CS_CONFIG_COMPLETE_EVENT = 0x2F
HCI_LE_CS_PROCEDURE_ENABLE_EVENT = 0x30
HCI_LE_CS_SUBEVENT_RESULT_EVENT = 0x31
HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT = 0x32
HCI_LE_CS_TEST_END_COMPLETE_EVENT = 0x33
HCI_LE_MONITORED_ADVERTISERS_REPORT_EVENT = 0x34
HCI_LE_FRAME_SPACE_UPDATE_EVENT = 0x35
# HCI Command
@@ -586,36 +573,11 @@ HCI_LE_SET_DATA_RELATED_ADDRESS_CHANGES_COMMAND = hci_c
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D)
HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E)
HCI_LE_SET_EXTENDED_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x007F)
HCI_LE_SET_DECISION_DATA_COMMAND = hci_command_op_code(0x08, 0x0080)
HCI_LE_SET_DECISION_INSTRUCTIONS_COMMAND = hci_command_op_code(0x08, 0x0081)
HCI_LE_SET_PERIODIC_ADVERTISING_SUBEVENT_DATA_COMMAND = hci_command_op_code(0x08, 0x0082)
HCI_LE_SET_PERIODIC_ADVERTISING_RESPONSE_DATA_COMMAND = hci_command_op_code(0x08, 0x0083)
HCI_LE_SET_PERIODIC_SYNC_SUBEVENT_COMMAND = hci_command_op_code(0x08, 0x0084)
HCI_LE_EXTENDED_CREATE_CONNECTION_V2_COMMAND = hci_command_op_code(0x08, 0x0085)
HCI_LE_SET_PERIODIC_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x0086)
HCI_LE_READ_ALL_LOCAL_SUPPORTED_FEATURES_COMMAND = hci_command_op_code(0x08, 0x0087)
HCI_LE_READ_ALL_REMOTE_FEATURES_COMMAND = hci_command_op_code(0x08, 0x0088)
HCI_LE_CS_READ_LOCAL_SUPPORTED_CAPABILITIES_COMMAND = hci_command_op_code(0x08, 0x0089)
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMMAND = hci_command_op_code(0x08, 0x008A)
HCI_LE_CS_WRITE_CACHED_REMOTE_SUPPORTED_CAPABILITIES = hci_command_op_code(0x08, 0x008B)
HCI_LE_CS_SECURITY_ENABLE_COMMAND = hci_command_op_code(0x08, 0x008C)
HCI_LE_CS_SET_DEFAULT_SETTINGS_COMMAND = hci_command_op_code(0x08, 0x008D)
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMMAND = hci_command_op_code(0x08, 0x008E)
HCI_LE_CS_WRITE_CACHED_REMOTE_FAE_TABLE_COMMAND = hci_command_op_code(0x08, 0x008F)
HCI_LE_CS_CREATE_CONFIG_COMMAND = hci_command_op_code(0x08, 0x0090)
HCI_LE_CS_REMOVE_CONFIG_COMMAND = hci_command_op_code(0x08, 0x0091)
HCI_LE_CS_SET_CHANNEL_CLASSIFICATION_COMMAND = hci_command_op_code(0x08, 0x0092)
HCI_LE_CS_SET_PROCEDURE_PARAMETERS_COMMAND = hci_command_op_code(0x08, 0x0093)
HCI_LE_CS_PROCEDURE_ENABLE_COMMAND = hci_command_op_code(0x08, 0x0094)
HCI_LE_CS_TEST_COMMAND = hci_command_op_code(0x08, 0x0095)
HCI_LE_CS_TEST_END_COMMAND = hci_command_op_code(0x08, 0x0096)
HCI_LE_SET_HOST_FEATURE_V2_COMMAND = hci_command_op_code(0x08, 0x0097)
HCI_LE_ADD_DEVICE_TO_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x0098)
HCI_LE_REMOVE_DEVICE_FROM_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x0099)
HCI_LE_CLEAR_MONITORED_ADVERTISERS_LIST_COMMAND = hci_command_op_code(0x08, 0x009A)
HCI_LE_READ_MONITORED_ADVERTISERS_LIST_SIZE_COMMAND = hci_command_op_code(0x08, 0x009B)
HCI_LE_ENABLE_MONITORING_ADVERTISERS_COMMAND = hci_command_op_code(0x08, 0x009C)
HCI_LE_FRAME_SPACE_UPDATE_COMMAND = hci_command_op_code(0x08, 0x009D)
# HCI Error Codes
@@ -1188,16 +1150,8 @@ class LeFeature(OpenIntEnum):
CHANNEL_CLASSIFICATION = 39
ADVERTISING_CODING_SELECTION = 40
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 41
DECISION_BASED_ADVERTISING_FILTERING = 42
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 43
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 44
UNSEGMENTED_FRAMED_MODE = 45
CHANNEL_SOUNDING = 46
CHANNEL_SOUNDING_HOST_SUPPORT = 47
CHANNEL_SOUNDING_TONE_QUALITY_INDICATION = 48
LL_EXTENDED_FEATURE_SET = 63
MONITORING_ADVERTISERS = 64
FRAME_SPACE_UPDATE = 65
class LeFeatureMask(enum.IntFlag):
LE_ENCRYPTION = 1 << LeFeature.LE_ENCRYPTION
@@ -1242,16 +1196,8 @@ class LeFeatureMask(enum.IntFlag):
CHANNEL_CLASSIFICATION = 1 << LeFeature.CHANNEL_CLASSIFICATION
ADVERTISING_CODING_SELECTION = 1 << LeFeature.ADVERTISING_CODING_SELECTION
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 1 << LeFeature.ADVERTISING_CODING_SELECTION_HOST_SUPPORT
DECISION_BASED_ADVERTISING_FILTERING = 1 << LeFeature.DECISION_BASED_ADVERTISING_FILTERING
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER
UNSEGMENTED_FRAMED_MODE = 1 << LeFeature.UNSEGMENTED_FRAMED_MODE
CHANNEL_SOUNDING = 1 << LeFeature.CHANNEL_SOUNDING
CHANNEL_SOUNDING_HOST_SUPPORT = 1 << LeFeature.CHANNEL_SOUNDING_HOST_SUPPORT
CHANNEL_SOUNDING_TONE_QUALITY_INDICATION = 1 << LeFeature.CHANNEL_SOUNDING_TONE_QUALITY_INDICATION
LL_EXTENDED_FEATURE_SET = 1 << LeFeature.LL_EXTENDED_FEATURE_SET
MONITORING_ADVERTISERS = 1 << LeFeature.MONITORING_ADVERTISERS
FRAME_SPACE_UPDATE = 1 << LeFeature.FRAME_SPACE_UPDATE
class LmpFeature(enum.IntEnum):
# Page 0 (Legacy LMP features)
@@ -1619,16 +1565,12 @@ class HCI_Object:
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
# Set fields first, because item_count might be 0.
for sub_field_name, _ in field:
result[sub_field_name] = []
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
result[sub_field_name].append(value)
result.setdefault(sub_field_name, []).append(value)
offset += size
continue
@@ -3040,27 +2982,6 @@ class HCI_Write_Inquiry_Scan_Activity_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('authentication_enable', 1),
]
)
class HCI_Read_Authentication_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.23 Read Authentication Enable Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('authentication_enable', 1)])
class HCI_Write_Authentication_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.24 Write Authentication Enable Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
@@ -3101,12 +3022,7 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('synchronous_flow_control_enable', 1),
]
)
@HCI_Command.command()
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
@@ -3275,13 +3191,7 @@ class HCI_Set_Event_Mask_Page_2_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('le_supported_host', 1),
('unused', 1),
]
)
@HCI_Command.command()
class HCI_Read_LE_Host_Support_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.78 Read LE Host Support Command
@@ -3414,39 +3324,13 @@ class HCI_Read_BD_ADDR_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
("status", STATUS_SPEC),
[("standard_codec_ids", 1)],
[("vendor_specific_codec_ids", 4)],
]
)
@HCI_Command.command()
class HCI_Read_Local_Supported_Codecs_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.8 Read Local Supported Codecs Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
("status", STATUS_SPEC),
[("standard_codec_ids", 1), ("standard_codec_transports", 1)],
[("vendor_specific_codec_ids", 4), ("vendor_specific_codec_transports", 1)],
]
)
class HCI_Read_Local_Supported_Codecs_V2_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.8 Read Local Supported Codecs Command
'''
class Transport(OpenIntEnum):
BR_EDR_ACL = 0x00
BR_EDR_SCO = 0x01
LE_CIS = 0x02
LE_BIS = 0x03
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle', 2)],
@@ -3604,12 +3488,7 @@ class HCI_LE_Set_Advertising_Parameters_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('tx_power_level', 1),
]
)
@HCI_Command.command()
class HCI_LE_Read_Advertising_Physical_Channel_Tx_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.6 LE Read Advertising Physical Channel Tx Power Command
@@ -3733,12 +3612,7 @@ class HCI_LE_Create_Connection_Cancel_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('filter_accept_list_size', 1),
]
)
@HCI_Command.command()
class HCI_LE_Read_Filter_Accept_List_Size_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.14 LE Read Filter Accept List Size Command
@@ -3849,12 +3723,7 @@ class HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('le_states', 8),
]
)
@HCI_Command.command()
class HCI_LE_Read_Supported_States_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.27 LE Read Supported States Command
@@ -4660,6 +4529,18 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# -----------------------------------------------------------------------------
@HCI_Command.command(
[
@@ -4695,32 +4576,6 @@ class HCI_LE_Set_Privacy_Mode_Command(HCI_Command):
return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode)
# -----------------------------------------------------------------------------
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
'''
class Enable(enum.IntFlag):
REPORTING_ENABLED = 1 << 0
DUPLICATE_FILTERING_ENABLED = 1 << 1
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('connection_handle', 2),
],
)
class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
@@ -4829,102 +4684,6 @@ class HCI_LE_Reject_CIS_Request_Command(HCI_Command):
reason: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('advertising_handle', 1),
('num_bis', 1),
('sdu_interval', 3),
('max_sdu', 2),
('max_transport_latency', 2),
('rtn', 1),
('phy', 1),
('packing', 1),
('framing', 1),
('encryption', 1),
('broadcast_code', 16),
],
)
class HCI_LE_Create_BIG_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.103 LE Create BIG command
'''
big_handle: int
advertising_handle: int
num_bis: int
sdu_interval: int
max_sdu: int
max_transport_latency: int
rtn: int
phy: int
packing: int
framing: int
encryption: int
broadcast_code: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('reason', {'size': 1, 'mapper': HCI_Constant.error_name}),
],
)
class HCI_LE_Terminate_BIG_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.105 LE Terminate BIG command
'''
big_handle: int
reason: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
('sync_handle', 2),
('encryption', 1),
('broadcast_code', 16),
('mse', 1),
('big_sync_timeout', 2),
[('bis', 1)],
],
)
class HCI_LE_BIG_Create_Sync_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.106 LE BIG Create Sync command
'''
big_handle: int
sync_handle: int
encryption: int
broadcast_code: int
mse: int
big_sync_timeout: int
bis: List[int]
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('big_handle', 1),
],
return_parameters_fields=[
('status', STATUS_SPEC),
('big_handle', 2),
],
)
class HCI_LE_BIG_Terminate_Sync_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.107. LE BIG Terminate Sync command
'''
big_handle: int
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
@@ -5760,27 +5519,6 @@ class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
'''
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('service_data', 2),
('sync_handle', 2),
('advertising_sid', 1),
('advertiser_address_type', Address.ADDRESS_TYPE_SPEC),
('advertiser_address', Address.parse_address_preceded_by_type),
('advertiser_phy', 1),
('periodic_advertising_interval', 2),
('advertiser_clock_accuracy', 1),
]
)
class HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event(HCI_LE_Meta_Event):
'''
See Bluetooth spec @ 7.7.65.24 LE Periodic Advertising Sync Transfer Received Event
'''
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
@@ -6473,23 +6211,6 @@ class HCI_Synchronous_Connection_Changed_Event(HCI_Event):
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('max_tx_latency', 2),
('max_rx_latency', 2),
('min_remote_timeout', 2),
('min_local_timeout', 2),
]
)
class HCI_Sniff_Subrating_Event(HCI_Event):
'''
See Bluetooth spec @ 7.7.37 Sniff Subrating Event
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[

View File

@@ -23,12 +23,13 @@ import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from bumble.hci import Address
from .hci import Address
# -----------------------------------------------------------------------------
@@ -219,27 +220,31 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
@@ -329,18 +334,17 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
@dataclass
class GetSetStatus:
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -406,6 +410,7 @@ class Device(HID):
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
@@ -423,9 +428,7 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
@@ -439,6 +442,7 @@ class Device(HID):
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -449,7 +453,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
@@ -460,12 +464,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
@@ -475,14 +480,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")

View File

@@ -171,7 +171,7 @@ class Host(AbortableEventEmitter):
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None
self.pending_response = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
@@ -514,9 +514,7 @@ class Host(AbortableEventEmitter):
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: Optional[int] = None
):
async def send_command(self, command, check_result=False):
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -528,13 +526,12 @@ class Host(AbortableEventEmitter):
try:
self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
response = self.pending_response.result()
response = await self.pending_response
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
status = response.status
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
@@ -628,21 +625,14 @@ class Host(AbortableEventEmitter):
# Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet: bytes) -> None:
try:
hci_packet = hci.HCI_Packet.from_bytes(packet)
except Exception as error:
logger.warning(f'!!! error parsing packet from bytes: {error}')
return
hci_packet = hci.HCI_Packet.from_bytes(packet)
if self.ready or (
isinstance(hci_packet, hci.HCI_Command_Complete_Event)
and hci_packet.command_opcode == hci.HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
logger.debug(
f'reset not done, ignoring packet from controller: {hci_packet}'
)
logger.debug('reset not done, ignoring packet from controller')
def on_transport_lost(self):
# Called by the source when the transport has been lost.

View File

@@ -25,10 +25,8 @@ import grpc.aio
from .config import Config
from .device import PandoraDevice
from .host import HostService
from .l2cap import L2CAPService
from .security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
@@ -79,7 +77,6 @@ async def serve(
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
# call hooks if any.
for hook in _SERVICERS_HOOKS:

View File

@@ -1,310 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import grpc
import json
import logging
from asyncio import Queue as AsyncQueue, Future
from . import utils
from .config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
ClassicChannelServer,
ClassicChannelSpec,
LeCreditBasedChannel,
LeCreditBasedChannelServer,
LeCreditBasedChannelSpec,
)
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
Channel as PandoraChannel,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
DisconnectRequest,
DisconnectResponse,
ReceiveRequest,
ReceiveResponse,
SendRequest,
SendResponse,
WaitConnectionRequest,
WaitConnectionResponse,
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
class ChannelContext:
close_future: Future
sdu_queue: AsyncQueue
class L2CAPService(L2CAPServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'L2CAP', 'device': device}
)
self.device = device
self.config = config
self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future()
sdu_queue: AsyncQueue = AsyncQueue()
def on_channel_sdu(sdu):
sdu_queue.put_nowait(sdu)
def on_close():
close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close)
return ChannelContext(close_future, sdu_queue)
@utils.rpc
async def WaitConnection(
self, request: WaitConnectionRequest, context: grpc.ServicerContext
) -> WaitConnectionResponse:
self.log.debug('WaitConnection')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
if channel_type.spsm in self.device.l2cap_channel_manager.le_coc_servers:
l2cap_server = self.device.l2cap_channel_manager.le_coc_servers[
channel_type.spsm
]
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
if channel_type.psm in self.device.l2cap_channel_manager.servers:
l2cap_server = self.device.l2cap_channel_manager.servers[
channel_type.psm
]
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
channel_future: Future[PandoraChannel] = (
asyncio.get_running_loop().create_future()
)
def on_l2cap_channel(l2cap_channel: L2capChannel):
try:
channel_context = self.register_event(l2cap_channel)
pandora_channel: PandoraChannel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
channel_future.set_result(pandora_channel)
except Exception as e:
self.log.error(f'Failed to set channel future: {e}')
if l2cap_server is None:
l2cap_server = self.device.create_l2cap_server(
spec=spec, handler=on_l2cap_channel
)
else:
l2cap_server.on('connection', on_l2cap_channel)
try:
self.log.debug('Waiting for a channel connection.')
pandora_channel: PandoraChannel = await channel_future
return WaitConnectionResponse(channel=pandora_channel)
except Exception as e:
self.log.warning(f'Exception: {e}')
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def WaitDisconnection(
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> WaitDisconnectionResponse:
try:
self.log.debug('WaitDisconnection')
await self.lookup_context(request.channel).close_future
self.log.debug("return WaitDisconnectionResponse")
return WaitDisconnectionResponse(success=empty_pb2.Empty())
except KeyError as e:
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
except Exception as e:
self.log.exception(f'WaitDisonnection failed: {e}')
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Receive(
self, request: ReceiveRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ReceiveResponse, None]:
self.log.debug('Receive')
oneof = request.WhichOneof('source')
self.log.debug(f'Source: {oneof}.')
pandora_channel = getattr(request, oneof)
sdu_queue = self.lookup_context(pandora_channel).sdu_queue
while sdu := await sdu_queue.get():
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
response = ReceiveResponse(data=sdu)
yield response
@utils.rpc
async def Connect(
self, request: ConnectRequest, context: grpc.ServicerContext
) -> ConnectResponse:
self.log.debug('Connect')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
try:
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
l2cap_channel = await connection.create_l2cap_channel(spec=spec)
channel_context = self.register_event(l2cap_channel)
pandora_channel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
return ConnectResponse(channel=pandora_channel)
except OutOfResourcesError as e:
self.log.error(e)
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
except InvalidArgumentError as e:
self.log.error(e)
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> DisconnectResponse:
try:
self.log.debug('Disconnect')
l2cap_channel = self.lookup_channel(request.channel)
if not l2cap_channel:
self.log.warning('Disconnect: Unable to find the channel')
return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
await l2cap_channel.disconnect()
return DisconnectResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return DisconnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Send(
self, request: SendRequest, context: grpc.ServicerContext
) -> SendResponse:
self.log.debug('Send')
try:
oneof = request.WhichOneof('sink')
self.log.debug(f'Sink: {oneof}.')
pandora_channel = getattr(request, oneof)
l2cap_channel = self.lookup_channel(pandora_channel)
if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.send_pdu(request.data)
else:
l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
def craft_pandora_channel(
self,
connection_handle: int,
l2cap_channel: L2capChannel,
) -> PandoraChannel:
parameters = {
"connection_handle": connection_handle,
"source_cid": l2cap_channel.source_cid,
}
cookie = any_pb2.Any()
cookie.value = json.dumps(parameters).encode()
return PandoraChannel(cookie=cookie)
def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
(connection_handle, source_cid) = json.loads(
pandora_channel.cookie.value
).values()
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
return self.channels[pandora_channel.cookie.value]

View File

@@ -1,520 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Audio Input Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import gatt
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
DelegatedCharacteristicAdapter,
TemplateService,
CharacteristicValue,
PackedCharacteristicAdapter,
GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
CHANGE_COUNTER_MAX_VALUE = 0xFF
GAIN_SETTINGS_MIN_VALUE = 0
GAIN_SETTINGS_MAX_VALUE = 255
class ErrorCode(OpenIntEnum):
'''
Cf. 1.6 Application error codes
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
MUTE_DISABLED = 0x82
VALUE_OUT_OF_RANGE = 0x83
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
class Mute(OpenIntEnum):
'''
Cf. 2.2.1.2 Mute Field
'''
NOT_MUTED = 0x00
MUTED = 0x01
DISABLED = 0x02
class GainMode(OpenIntEnum):
'''
Cf. 2.2.1.3 Gain Mode
'''
MANUAL_ONLY = 0x00
AUTOMATIC_ONLY = 0x01
MANUAL = 0x02
AUTOMATIC = 0x03
class AudioInputStatus(OpenIntEnum):
'''
Cf. 3.4 Audio Input Status
'''
INATIVE = 0x00
ACTIVE = 0x01
class AudioInputControlPointOpCode(OpenIntEnum):
'''
Cf. 3.5.1 Audio Input Control Point procedure requirements
'''
SET_GAIN_SETTING = 0x00
UNMUTE = 0x02
MUTE = 0x03
SET_MANUAL_GAIN_MODE = 0x04
SET_AUTOMATIC_GAIN_MODE = 0x05
# -----------------------------------------------------------------------------
@dataclass
class AudioInputState:
'''
Cf. 2.2.1 Audio Input State
'''
gain_settings: int = 0
mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return bytes(
[self.gain_settings, self.mute, self.gain_mode, self.change_counter]
)
@classmethod
def from_bytes(cls, data: bytes):
gain_settings, mute, gain_mode, change_counter = struct.unpack("BBBB", data)
return cls(gain_settings, mute, gain_mode, change_counter)
def update_gain_settings_unit(self, gain_settings_unit: int) -> None:
self.gain_settings_unit = gain_settings_unit
def increment_gain_settings(self, gain_settings_unit: int) -> None:
self.gain_settings += gain_settings_unit
self.increment_change_counter()
def decrement_gain_settings(self) -> None:
self.gain_settings -= self.gain_settings_unit
self.increment_change_counter()
def increment_change_counter(self):
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=bytes(self)
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class GainSettingsProperties:
'''
Cf. 3.2 Gain Settings Properties
'''
gain_settings_unit: int = 1
gain_settings_minimum: int = GAIN_SETTINGS_MIN_VALUE
gain_settings_maximum: int = GAIN_SETTINGS_MAX_VALUE
@classmethod
def from_bytes(cls, data: bytes):
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
struct.unpack('BBB', data)
)
GainSettingsProperties(
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
)
def __bytes__(self) -> bytes:
return bytes(
[
self.gain_settings_unit,
self.gain_settings_minimum,
self.gain_settings_maximum,
]
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class AudioInputControlPoint:
'''
Cf. 3.5.2 Audio Input Control Point
'''
audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
gain_settings_operand = value[2]
await self._set_gain_settings(connection, gain_settings_operand)
elif opcode == AudioInputControlPointOpCode.UNMUTE:
await self._unmute(connection)
elif opcode == AudioInputControlPointOpCode.MUTE:
change_counter_operand = value[1]
await self._mute(connection, change_counter_operand)
elif opcode == AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE:
await self._set_manual_gain_mode(connection)
elif opcode == AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE:
await self._set_automatic_gain_mode(connection)
else:
logger.error(f"OpCode value is incorrect: {opcode}")
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
async def _set_gain_settings(
self, connection: Connection, gain_settings_operand: int
) -> None:
'''Cf. 3.5.2.1 Set Gain Settings Procedure'''
gain_mode = self.audio_input_state.gain_mode
logger.error(f"set_gain_setting: gain_mode: {gain_mode}")
if not (gain_mode == GainMode.MANUAL or gain_mode == GainMode.MANUAL_ONLY):
logger.warning(
"GainMode should be either MANUAL or MANUAL_ONLY Cf Spec Audio Input Control Service 3.5.2.1"
)
return
if (
gain_settings_operand < self.gain_settings_properties.gain_settings_minimum
or gain_settings_operand
> self.gain_settings_properties.gain_settings_maximum
):
logger.error("gain_seetings value out of range")
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
if self.audio_input_state.gain_settings != gain_settings_operand:
self.audio_input_state.gain_settings = gain_settings_operand
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _unmute(self, connection: Connection):
'''Cf. 3.5.2.2 Unmute procedure'''
logger.error(f'unmute: {self.audio_input_state.mute}')
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("unmute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if mute == Mute.NOT_MUTED:
return
self.audio_input_state.mute = Mute.NOT_MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _mute(self, connection: Connection, change_counter_operand: int) -> None:
'''Cf. 3.5.5.2 Mute procedure'''
change_counter = self.audio_input_state.change_counter
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("mute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if mute == Mute.MUTED:
return
self.audio_input_state.mute = Mute.MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_manual_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.4 Set Manual Gain Mode procedure'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.MANUAL:
return
self.audio_input_state.gain_mode = GainMode.MANUAL
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_automatic_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.5 Set Automatic Gain Mode'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.AUTOMATIC:
return
self.audio_input_state.gain_mode = GainMode.AUTOMATIC
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioInputDescription:
'''
Cf. 3.6 Audio Input Description
'''
audio_input_description: str = "Bluetooth"
attribute_value: Optional[CharacteristicValue] = None
@classmethod
def from_bytes(cls, data: bytes):
return cls(audio_input_description=data.decode('utf-8'))
def __bytes__(self) -> bytes:
return self.audio_input_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return self.audio_input_description.encode('utf-8')
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_input_description = value.decode('utf-8')
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value
)
class AICSService(TemplateService):
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
def __init__(
self,
audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_type: str = "local",
audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: Optional[AudioInputDescription] = None,
):
self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state
)
self.gain_settings_properties = (
GainSettingsProperties()
if gain_settings_properties is None
else gain_settings_properties
)
self.audio_input_status = (
AudioInputStatus.ACTIVE
if audio_input_status is None
else audio_input_status
)
self.audio_input_description = (
AudioInputDescription()
if audio_input_description is None
else audio_input_description
)
self.audio_input_control_point: AudioInputControlPoint = AudioInputControlPoint(
self.audio_input_state, self.gain_settings_properties
)
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.audio_input_state.on_read),
),
encode=lambda value: bytes(value),
)
self.audio_input_state.attribute_value = (
self.audio_input_state_characteristic.value
)
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.gain_settings_properties.on_read),
)
)
self.audio_input_type_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=audio_input_type,
)
self.audio_input_status_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([self.audio_input_status]),
)
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.audio_input_control_point.on_write
),
)
)
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.audio_input_description.on_write,
read=self.audio_input_description.on_read,
),
)
)
self.audio_input_description.attribute_value = (
self.audio_input_control_point_characteristic.value
)
super().__init__(
characteristics=[
self.audio_input_state_characteristic, # type: ignore
self.gain_settings_properties_characteristic, # type: ignore
self.audio_input_type_characteristic, # type: ignore
self.audio_input_status_characteristic, # type: ignore
self.audio_input_control_point_characteristic, # type: ignore
self.audio_input_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class AICSServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = AICSService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
self.audio_input_state = DelegatedCharacteristicAdapter(
characteristic=characteristics[0], decode=AudioInputState.from_bytes
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Gain Settings Attribute Characteristic not found"
)
self.gain_settings_properties = PackedCharacteristicAdapter(
characteristics[0],
'BBB',
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Status Characteristic not found"
)
self.audio_input_status = PackedCharacteristicAdapter(
characteristics[0],
'B',
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Control Point Characteristic not found"
)
self.audio_input_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Description Characteristic not found"
)
self.audio_input_description = characteristics[0]

View File

@@ -1,739 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Audio Stream Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import logging
import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
# -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# -----------------------------------------------------------------------------
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
# -----------------------------------------------------------------------------
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = (),
sink_ase_id: Sequence[int] = (),
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

View File

@@ -1,295 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import struct
import logging
from typing import List, Optional, Callable, Union, Any
from bumble import l2cap
from bumble import utils
from bumble import gatt
from bumble import gatt_client
from bumble.core import AdvertisingData
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
_logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class DeviceCapabilities(enum.IntFlag):
IS_RIGHT = 0x01
IS_DUAL = 0x02
CSIS_SUPPORTED = 0x04
class FeatureMap(enum.IntFlag):
LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED = 0x01
class AudioType(utils.OpenIntEnum):
UNKNOWN = 0x00
RINGTONE = 0x01
PHONE_CALL = 0x02
MEDIA = 0x03
class OpCode(utils.OpenIntEnum):
START = 1
STOP = 2
STATUS = 3
class Codec(utils.OpenIntEnum):
G_722_16KHZ = 1
class SupportedCodecs(enum.IntFlag):
G_722_16KHZ = 1 << Codec.G_722_16KHZ
class PeripheralStatus(utils.OpenIntEnum):
"""Status update on the other peripheral."""
OTHER_PERIPHERAL_DISCONNECTED = 1
OTHER_PERIPHERAL_CONNECTED = 2
CONNECTION_PARAMETER_UPDATED = 3
class AudioStatus(utils.OpenIntEnum):
"""Status report field for the audio control point."""
OK = 0
UNKNOWN_COMMAND = -1
ILLEGAL_PARAMETERS = -2
# -----------------------------------------------------------------------------
class AshaService(gatt.TemplateService):
UUID = gatt.GATT_ASHA_SERVICE
audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None
volume: Optional[int] = None
other_state: Optional[int] = None
connection: Optional[Connection] = None
def __init__(
self,
capability: int,
hisyncid: Union[List[int], bytes],
device: Device,
psm: int = 0,
audio_sink: Optional[Callable[[bytes], Any]] = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01,
render_delay_milliseconds: int = 0,
supported_codecs: int = SupportedCodecs.G_722_16KHZ,
) -> None:
if len(hisyncid) != 8:
_logger.warning('HiSyncId should have a length of 8, got %d', len(hisyncid))
self.hisyncid = bytes(hisyncid)
self.capability = capability
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
self.audio_sink = audio_sink
self.protocol_version = protocol_version
self.read_only_properties_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack(
"<BB8sBH2sH",
protocol_version,
capability,
self.hisyncid,
feature_map,
render_delay_milliseconds,
b'\x00\x00',
supported_codecs,
),
)
self.audio_control_point_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
)
self.audio_status_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.READABLE,
bytes([AudioStatus.OK]),
)
self.volume_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_volume_write),
)
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=self._on_connection,
).psm
self.le_psm_out_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self) -> bytes:
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(gatt.GATT_ASHA_SERVICE)
+ bytes([self.protocol_version, self.capability])
+ self.hisyncid[:4],
),
]
)
)
# Handler for audio control commands
async def _on_audio_control_point_write(
self, connection: Optional[Connection], value: bytes
) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == OpCode.START:
# Start
self.active_codec = Codec(value[1])
self.audio_type = AudioType(value[2])
self.volume = value[3]
self.other_state = value[4]
_logger.debug(
f'### START: codec={self.active_codec.name}, '
f'audio_type={self.audio_type.name}, '
f'volume={self.volume}, '
f'other_state={self.other_state}'
)
self.emit('started')
elif opcode == OpCode.STOP:
_logger.debug('### STOP')
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('stopped')
elif opcode == OpCode.STATUS:
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
if self.connection is None and connection:
self.connection = connection
def on_disconnection(_reason) -> None:
self.connection = None
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('disconnected')
connection.once('disconnection', on_disconnection)
# OPCODE_STATUS does not need audio status point update
if opcode != OpCode.STATUS:
await self.device.notify_subscribers(
self.audio_status_characteristic, force=True
)
# Handler for volume control
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0]
self.emit('volume_changed')
# Register an L2CAP CoC server
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
def on_data(data: bytes) -> None:
if self.audio_sink: # pylint: disable=not-callable
self.audio_sink(data)
channel.sink = on_data
# -----------------------------------------------------------------------------
class AshaServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AshaService
read_only_properties_characteristic: gatt_client.CharacteristicProxy
audio_control_point_characteristic: gatt_client.CharacteristicProxy
audio_status_point_characteristic: gatt_client.CharacteristicProxy
volume_characteristic: gatt_client.CharacteristicProxy
psm_characteristic: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
for uuid, attribute_name in (
(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
'read_only_properties_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
'audio_control_point_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
'audio_status_point_characteristic',
),
(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
'volume_characteristic',
),
(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
'psm_characteristic',
),
):
if not (
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
):
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
setattr(self, attribute_name, characteristics[0])

View File

@@ -0,0 +1,193 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from bumble import l2cap
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -24,12 +24,15 @@ import enum
import struct
import functools
import logging
from typing import List
from typing import Optional, List, Union, Type, Dict, Any, Tuple
from typing_extensions import Self
from bumble import core
from bumble import colors
from bumble import device
from bumble import hci
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from bumble.profiles import le_audio
@@ -248,6 +251,231 @@ class AnnouncementType(utils.OpenIntEnum):
TARGETED = 0x01
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
@dataclasses.dataclass
class UnicastServerAdvertisingData:
"""Advertising Data for ASCS."""
@@ -455,6 +683,54 @@ class CodecSpecificConfiguration:
)
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
@dataclasses.dataclass
class BroadcastAudioAnnouncement:
broadcast_id: int
@@ -546,3 +822,603 @@ class BasicAudioAnnouncement:
)
return cls(presentation_delay, subgroups)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = [],
sink_ase_id: Sequence[int] = [],
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

View File

@@ -1,440 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Broadcast Audio Scan Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, List, Optional, Sequence
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
from bumble import utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
OPCODE_NOT_SUPPORTED = 0x80
INVALID_SOURCE_ID = 0x81
# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
return bytes([len(subgroups)]) + b"".join(
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
+ subgroup.metadata
for subgroup in subgroups
)
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
num_subgroups = data[0]
offset = 1
subgroups = []
for _ in range(num_subgroups):
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
metadata_length = data[offset + 4]
metadata = data[offset + 5 : offset + 5 + metadata_length]
offset += 5 + metadata_length
subgroups.append(SubgroupInfo(bis_sync, metadata))
return subgroups
# -----------------------------------------------------------------------------
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
@dataclasses.dataclass
class SubgroupInfo:
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
bis_sync: int
metadata: bytes
class ControlPointOperation:
class OpCode(utils.OpenIntEnum):
REMOTE_SCAN_STOPPED = 0x00
REMOTE_SCAN_STARTED = 0x01
ADD_SOURCE = 0x02
MODIFY_SOURCE = 0x03
SET_BROADCAST_CODE = 0x04
REMOVE_SOURCE = 0x05
op_code: OpCode
parameters: bytes
@classmethod
def from_bytes(cls, data: bytes) -> ControlPointOperation:
op_code = data[0]
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
return RemoteScanStoppedOperation()
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
return RemoteScanStartedOperation()
if op_code == cls.OpCode.ADD_SOURCE:
return AddSourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.MODIFY_SOURCE:
return ModifySourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.SET_BROADCAST_CODE:
return SetBroadcastCodeOperation.from_parameters(data[1:])
if op_code == cls.OpCode.REMOVE_SOURCE:
return RemoveSourceOperation.from_parameters(data[1:])
raise core.InvalidArgumentError("invalid op code")
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
self.op_code = op_code
self.parameters = parameters
def __bytes__(self) -> bytes:
return bytes([self.op_code]) + self.parameters
class RemoteScanStoppedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
class RemoteScanStartedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
class AddSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
instance.parameters = parameters
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
parameters, 1
)[1]
instance.advertising_sid = parameters[7]
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
instance.subgroups = decode_subgroups(parameters[14:])
return instance
def __init__(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.ADD_SOURCE,
struct.pack(
"<B6sB3sBH",
advertiser_address.address_type,
bytes(advertiser_address),
advertising_sid,
broadcast_id.to_bytes(3, "little"),
pa_sync,
pa_interval,
)
+ encode_subgroups(subgroups),
)
self.advertiser_address = advertiser_address
self.advertising_sid = advertising_sid
self.broadcast_id = broadcast_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class ModifySourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
instance.subgroups = decode_subgroups(parameters[4:])
return instance
def __init__(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.MODIFY_SOURCE,
struct.pack("<BBH", source_id, pa_sync, pa_interval)
+ encode_subgroups(subgroups),
)
self.source_id = source_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class SetBroadcastCodeOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.broadcast_code = parameters[1:17]
return instance
def __init__(
self,
source_id: int,
broadcast_code: bytes,
) -> None:
super().__init__(
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
bytes([source_id]) + broadcast_code,
)
self.source_id = source_id
self.broadcast_code = broadcast_code
if len(self.broadcast_code) != 16:
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
class RemoveSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
return instance
def __init__(self, source_id: int) -> None:
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
self.source_id = source_id
@dataclasses.dataclass
class BroadcastReceiveState:
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
NOT_SYNCHRONIZED_TO_PA = 0x00
SYNCINFO_REQUEST = 0x01
SYNCHRONIZED_TO_PA = 0x02
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
NO_PAST = 0x04
class BigEncryption(utils.OpenIntEnum):
NOT_ENCRYPTED = 0x00
BROADCAST_CODE_REQUIRED = 0x01
DECRYPTING = 0x02
BAD_CODE = 0x03
source_id: int
source_address: hci.Address
source_adv_sid: int
broadcast_id: int
pa_sync_state: PeriodicAdvertisingSyncState
big_encryption: BigEncryption
bad_code: bytes
subgroups: List[SubgroupInfo]
@classmethod
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
if not data:
return None
source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8]
broadcast_id = int.from_bytes(data[9:12], "little")
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
big_encryption = cls.BigEncryption(data[13])
if big_encryption == cls.BigEncryption.BAD_CODE:
bad_code = data[14:30]
subgroups = decode_subgroups(data[30:])
else:
bad_code = b""
subgroups = decode_subgroups(data[14:])
return cls(
source_id,
source_address,
source_adv_sid,
broadcast_id,
pa_sync_state,
big_encryption,
bad_code,
subgroups,
)
def __bytes__(self) -> bytes:
return (
struct.pack(
"<BB6sB3sBB",
self.source_id,
self.source_address.address_type,
bytes(self.source_address),
self.source_adv_sid,
self.broadcast_id.to_bytes(3, "little"),
self.pa_sync_state,
self.big_encryption,
)
+ self.bad_code
+ encode_subgroups(self.subgroups)
)
# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
def __init__(self):
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
write=self.on_broadcast_audio_scan_control_point_write
),
)
self.broadcast_receive_state_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
b"12", # TEST
)
super().__init__([self.battery_level_characteristic])
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
) -> None:
pass
# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BroadcastAudioScanService
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [
gatt.DelegatedCharacteristicAdapter(
characteristic, decode=BroadcastReceiveState.from_bytes
)
for characteristic in characteristics
]
async def send_control_point_operation(
self, operation: ControlPointOperation
) -> None:
await self.broadcast_audio_scan_control_point.write_value(
bytes(operation), with_response=True
)
async def remote_scan_started(self) -> None:
await self.send_control_point_operation(RemoteScanStartedOperation())
async def remote_scan_stopped(self) -> None:
await self.send_control_point_operation(RemoteScanStoppedOperation())
async def add_source(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
AddSourceOperation(
advertiser_address,
advertising_sid,
broadcast_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def modify_source(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
ModifySourceOperation(
source_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def remove_source(self, source_id: int) -> None:
await self.send_control_point_operation(RemoveSourceOperation(source_id))

View File

@@ -1,110 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic Access Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Tuple, Union
from bumble.core import Appearance
from bumble.gatt import (
TemplateService,
Characteristic,
CharacteristicAdapter,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(TemplateService):
UUID = GATT_GENERIC_ACCESS_SERVICE
def __init__(
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
self.appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', appearance_int),
)
super().__init__(
[self.device_name_characteristic, self.appearance_characteristic]
)
# -----------------------------------------------------------------------------
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicAdapter]
appearance: Optional[DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC
):
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
else:
self.device_name = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_APPEARANCE_CHARACTERISTIC
):
self.appearance = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Appearance.from_int(
struct.unpack_from('<H', value, 0)[0],
),
)
else:
self.appearance = None

View File

@@ -1,665 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import functools
from bumble import att, gatt, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Device, Connection
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.hci import Address
from dataclasses import dataclass, field
import logging
from typing import Dict, List, Optional, Set, Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(OpenIntEnum):
'''See Hearing Access Service 2.4. Attribute Profile error codes.'''
INVALID_OPCODE = 0x80
WRITE_NAME_NOT_ALLOWED = 0x81
PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
PRESET_OPERATION_NOT_POSSIBLE = 0x83
INVALID_PARAMETERS_LENGTH = 0x84
class HearingAidType(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
BINAURAL_HEARING_AID = 0b00
MONAURAL_HEARING_AID = 0b01
BANDED_HEARING_AID = 0b10
class PresetSynchronizationSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
class IndependentPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
IDENTICAL_PRESET_RECORD = 0b0
DIFFERENT_PRESET_RECORD = 0b1
class DynamicPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
PRESET_RECORDS_MAY_CHANGE = 0b1
class WritablePresetsSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
class HearingAidPresetControlPointOpcode(OpenIntEnum):
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
# fmt: off
READ_PRESETS_REQUEST = 0x01
READ_PRESET_RESPONSE = 0x02
PRESET_CHANGED = 0x03
WRITE_PRESET_NAME = 0x04
SET_ACTIVE_PRESET = 0x05
SET_NEXT_PRESET = 0x06
SET_PREVIOUS_PRESET = 0x07
SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY = 0x08
SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY = 0x09
SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
@dataclass
class HearingAidFeatures:
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
hearing_aid_type: HearingAidType
preset_synchronization_support: PresetSynchronizationSupport
independent_presets: IndependentPresets
dynamic_presets: DynamicPresets
writable_presets_support: WritablePresetsSupport
def __bytes__(self) -> bytes:
return bytes(
[
(self.hearing_aid_type << 0)
| (self.preset_synchronization_support << 2)
| (self.independent_presets << 3)
| (self.dynamic_presets << 4)
| (self.writable_presets_support << 5)
]
)
def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
return HearingAidFeatures(
HearingAidType(data & 0b11),
PresetSynchronizationSupport(data >> 2 & 0b1),
IndependentPresets(data >> 3 & 0b1),
DynamicPresets(data >> 4 & 0b1),
WritablePresetsSupport(data >> 5 & 0b1),
)
@dataclass
class PresetChangedOperation:
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
class ChangeId(OpenIntEnum):
# fmt: off
GENERIC_UPDATE = 0x00
PRESET_RECORD_DELETED = 0x01
PRESET_RECORD_AVAILABLE = 0x02
PRESET_RECORD_UNAVAILABLE = 0x03
@dataclass
class Generic:
prev_index: int
preset_record: PresetRecord
def __bytes__(self) -> bytes:
return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId
additional_parameters: Union[Generic, int]
def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
additional_parameters_bytes = bytes(self.additional_parameters)
else:
additional_parameters_bytes = bytes([self.additional_parameters])
return (
bytes(
[
HearingAidPresetControlPointOpcode.PRESET_CHANGED,
self.change_id,
is_last,
]
)
+ additional_parameters_bytes
)
class PresetChangedOperationDeleted(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
self.additional_parameters = index
class PresetChangedOperationAvailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
self.additional_parameters = index
class PresetChangedOperationUnavailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
self.additional_parameters = index
@dataclass
class PresetRecord:
'''See Hearing Access Service 2.8. Preset record.'''
@dataclass
class Property:
class Writable(OpenIntEnum):
CANNOT_BE_WRITTEN = 0b0
CAN_BE_WRITTEN = 0b1
class IsAvailable(OpenIntEnum):
IS_UNAVAILABLE = 0b0
IS_AVAILABLE = 0b1
writable: Writable = Writable.CAN_BE_WRITTEN
is_available: IsAvailable = IsAvailable.IS_AVAILABLE
def __bytes__(self) -> bytes:
return bytes([self.writable | (self.is_available << 1)])
index: int
name: str
properties: Property = field(default_factory=Property)
def __bytes__(self) -> bytes:
return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8')
def is_available(self) -> bool:
return (
self.properties.is_available
== PresetRecord.Property.IsAvailable.IS_AVAILABLE
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
UUID = gatt.GATT_HEARING_ACCESS_SERVICE
hearing_aid_features_characteristic: gatt.Characteristic
hearing_aid_preset_control_point: gatt.Characteristic
active_preset_index_characteristic: gatt.Characteristic
active_preset_index: int
active_preset_index_per_device: Dict[Address, int]
device: Device
server_features: HearingAidFeatures
preset_records: Dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool
preset_changed_operations_history_per_device: Dict[
Address, List[PresetChangedOperation]
]
# Keep an updated list of connected client to send notification to
currently_connected_clients: Set[Connection]
def __init__(
self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
) -> None:
self.active_preset_index_per_device = {}
self.read_presets_request_in_progress = False
self.preset_changed_operations_history_per_device = {}
self.currently_connected_clients = set()
self.device = device
self.server_features = features
if len(presets) < 1:
raise InvalidArgumentError(f'Invalid presets: {presets}')
self.preset_records = {}
for p in presets:
if len(p.name.encode()) < 1 or len(p.name.encode()) > 40:
raise InvalidArgumentError(f'Invalid name: {p.name}')
self.preset_records[p.index] = p
# associate the lowest index as the current active preset at startup
self.active_preset_index = sorted(self.preset_records.keys())[0]
@device.on('connection') # type: ignore
def on_connection(connection: Connection) -> None:
@connection.on('disconnection') # type: ignore
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
# TODO Should we filter on device bonded && device is HAP ?
self.currently_connected_clients.add(connection)
if (
connection.peer_address
not in self.preset_changed_operations_history_per_device
):
self.preset_changed_operations_history_per_device[
connection.peer_address
] = []
return
async def on_connection_async() -> None:
# Send all the PresetChangedOperation that occur when not connected
await self._preset_changed_operation(connection)
# Update the active preset index if needed
await self.notify_active_preset_for_connection(connection)
connection.abort_on('disconnection', on_connection_async())
self.hearing_aid_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes(self.server_features),
)
self.hearing_aid_preset_control_point = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.INDICATE
),
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(
write=self._on_write_hearing_aid_preset_control_point
),
)
self.active_preset_index_characteristic = gatt.Characteristic(
uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
)
super().__init__(
[
self.hearing_aid_features_characteristic,
self.hearing_aid_preset_control_point,
self.active_preset_index_characteristic,
]
)
def _on_read_active_preset_index(
self, __connection__: Optional[Connection]
) -> bytes:
return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded
def on_forget(self, addr: Address) -> None:
self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point(
self, connection: Optional[Connection], value: bytes
):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value)
async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}')
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
self.read_presets_request_in_progress = True
start_index = value[1]
if start_index == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
num_presets = value[2]
if num_presets == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
# Sending `num_presets` presets ordered by increasing index field, starting from start_index
presets = [
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
AsyncRunner.spawn(self._read_preset_response(connection, presets))
async def _read_preset_response(
self, connection: Connection, presets: List[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=bytes(
[
HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
i == len(presets) - 1,
]
)
+ bytes(preset),
)
finally:
# indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
self.read_presets_request_in_progress = False
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot set active preset as unavailable')
preset = self.preset_records[index]
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
op_list = self.preset_changed_operations_history_per_device.get(
connection.peer_address, []
)
# Notification will be sent in index order
def get_op_index(op: PresetChangedOperation) -> int:
if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
return op.additional_parameters.prev_index
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
try:
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=op_list[0].to_bytes(len(op_list) == 1),
)
# Remove item once sent, and keep the non sent item in the list
op_list.pop(0)
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(
self, connection: Optional[Connection], value: bytes
):
assert connection
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.writable
== PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
):
raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)
name = value[2:].decode('utf-8')
if not name or len(name) > 40:
raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)
preset.name = name
await self.generic_update(
PresetChangedOperation(
PresetChangedOperation.ChangeId.GENERIC_UPDATE,
PresetChangedOperation.Generic(index, preset),
)
)
async def notify_active_preset_for_connection(self, connection: Connection) -> None:
if (
self.active_preset_index_per_device.get(connection.peer_address, 0x00)
== self.active_preset_index
):
# Nothing to do, peer is already updated
return
await connection.device.notify_subscriber(
connection,
attribute=self.active_preset_index_characteristic,
value=bytes([self.active_preset_index]),
)
self.active_preset_index_per_device[connection.peer_address] = (
self.active_preset_index
)
async def notify_active_preset(self) -> None:
for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection)
async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.is_available
!= PresetRecord.Property.IsAvailable.IS_AVAILABLE
):
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if index == self.active_preset_index:
# Already at correct value
return
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(
self, connection: Optional[Connection], value: bytes
):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
'''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
if not first_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, True)
async def _on_set_active_preset_synchronized_locally(
self, connection: Optional[Connection], value: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_active_preset(connection, value)
# TODO (low priority) inform other server of the change
async def _on_set_next_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, False)
# TODO (low priority) inform other server of the change
async def _on_set_previous_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, True)
# TODO (low priority) inform other server of the change
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.server_features = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
)[0],
'B',
)
self.hearing_aid_preset_control_point = (
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
)[0]
)
self.active_preset_index = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
)[0],
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
)

View File

@@ -1,210 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Published Audio Capabilities Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import Optional, Sequence, Union
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
from bumble.profiles import le_audio
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]

View File

@@ -1,89 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Telephony and Media Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import logging
import struct
from bumble.gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Role(enum.IntFlag):
CALL_GATEWAY = 1 << 0
CALL_TERMINAL = 1 << 1
UNICAST_MEDIA_SENDER = 1 << 2
UNICAST_MEDIA_RECEIVER = 1 << 3
BROADCAST_MEDIA_SENDER = 1 << 4
BROADCAST_MEDIA_RECEIVER = 1 << 5
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioService(TemplateService):
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
def __init__(self, role: Role):
self.role_characteristic = Characteristic(
GATT_TMAP_ROLE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', int(role)),
)
super().__init__([self.role_characteristic])
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = TelephonyAndMediaAudioService
role: DelegatedCharacteristicAdapter
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0],
),
)

View File

@@ -24,7 +24,7 @@ from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional, Sequence
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
@@ -88,7 +88,6 @@ class VolumeControlService(gatt.TemplateService):
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
included_services: Sequence[gatt.Service] = (),
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
@@ -118,12 +117,11 @@ class VolumeControlService(gatt.TemplateService):
)
super().__init__(
characteristics=[
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
],
included_services=list(included_services),
]
)
@property

View File

@@ -764,9 +764,7 @@ class Session:
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
self.oob_data_flag = (
1 if pairing_config.oob and pairing_config.oob.peer_data else 0
)
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses
self_address = connection.self_resolvable_address or connection.self_address
@@ -1016,10 +1014,8 @@ class Session:
self.send_command(response)
def send_pairing_confirm_command(self) -> None:
if self.pairing_method != PairingMethod.OOB:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
@@ -1082,19 +1078,11 @@ class Session:
)
def send_identity_address_command(self) -> None:
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
identity_address = self.manager.device.public_address
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
identity_address = self.manager.device.static_address
else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
identity_address = {
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
@@ -1739,6 +1727,7 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
@@ -1746,22 +1735,6 @@ class Session:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
elif self.pairing_method == PairingMethod.OOB:
if self.is_initiator:
if self.peer_oob_data:
rb = self.peer_oob_data.r
ra = self.r
else:
rb = bytes(16)
ra = self.r
else:
if self.peer_oob_data:
ra = self.peer_oob_data.r
rb = self.r
else:
ra = bytes(16)
rb = self.r
else:
return

View File

@@ -248,28 +248,26 @@ class AsyncPipeSink:
# -----------------------------------------------------------------------------
class BaseSource:
class ParserSource:
"""
Base class designed to be subclassed by transport-specific source classes
"""
terminated: asyncio.Future[None]
sink: Optional[TransportSink]
parser: PacketParser
def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
self.sink = None
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
self.parser.set_packet_sink(sink)
def on_transport_lost(self) -> None:
if not self.terminated.done():
self.terminated.set_result(None)
if self.sink:
if hasattr(self.sink, 'on_transport_lost'):
self.sink.on_transport_lost()
self.terminated.set_result(None)
if self.parser.sink:
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost()
async def wait_for_termination(self) -> None:
"""
@@ -282,23 +280,6 @@ class BaseSource:
pass
# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
"""
Base class for sources that use an HCI parser.
"""
parser: PacketParser
def __init__(self) -> None:
super().__init__()
self.parser = PacketParser()
def set_packet_sink(self, sink: TransportSink) -> None:
super().set_packet_sink(sink)
self.parser.set_packet_sink(sink)
# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None:

View File

@@ -23,7 +23,7 @@ import time
import usb.core
import usb.util
from typing import Optional, Set
from typing import Optional
from usb.core import Device as UsbDevice
from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
@@ -46,11 +46,6 @@ RESET_DELAY = 3
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Global
# -----------------------------------------------------------------------------
devices_in_use: Set[int] = set()
# -----------------------------------------------------------------------------
async def open_pyusb_transport(spec: str) -> Transport:
@@ -221,7 +216,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
async def close(self):
await self.source.stop()
await self.sink.stop()
devices_in_use.remove(device.address)
usb.util.release_interface(self.device, 0)
usb_find = usb.core.find
@@ -239,18 +233,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
spec = spec[1:]
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = None
devices = usb_find(
find_all=True, idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
for d in devices:
if d.address in devices_in_use:
continue
device = d
devices_in_use.add(d.address)
break
if device is None:
raise ValueError('device already in use')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
elif '-' in spec:
def device_path(device):

View File

@@ -24,7 +24,7 @@ import platform
import usb1
from bumble.transport.common import Transport, BaseSource, TransportInitError
from bumble.transport.common import Transport, ParserSource, TransportInitError
from bumble import hci
from bumble.colors import color
@@ -139,7 +139,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.packets.put_nowait(packet)
def transfer_callback(self, transfer):
self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
self.acl_out_transfer_ready.release()
status = transfer.getStatus()
# pylint: disable=no-member
@@ -208,7 +208,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError:
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, BaseSource):
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, acl_in, events_in):
super().__init__()
self.device = device
@@ -285,13 +285,7 @@ async def open_usb_transport(spec: str) -> Transport:
packet = await self.queue.get()
except asyncio.CancelledError:
return
if self.sink:
try:
self.sink.on_packet(packet)
except Exception as error:
logger.exception(
color(f'!!! Exception in sink.on_packet: {error}', 'red')
)
self.parser.feed_data(packet)
def close(self):
self.closed = True

View File

@@ -1,95 +0,0 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble ASHA Sink</span>
</div>
</nav>
<br>
<div class="container">
<div class="row">
<div class="col-auto">
<button id="connect-audio" class="btn btn-danger" onclick="connectAudio()">Connect Audio</button>
</div>
</div>
<hr>
<div class="row">
<div class="col-4">
<label class="form-label">Browser Gain</label>
<input type="range" class="form-range" id="browser-gain" min="0" max="2" value="1" step="0.1"
onchange="setGain()">
</div>
</div>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let atResponseInput = document.getElementById("at_response")
let gainInput = document.getElementById('browser-gain')
let log = document.getElementById("log")
let socket = new WebSocket('ws://localhost:8888');
let sampleRate = 0;
let player;
socket.binaryType = "arraybuffer";
socket.onopen = _ => {
log.textContent += 'SOCKET OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'SOCKET CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'SOCKET ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
log.textContent += `<-- ${event.data}\n`
} else {
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 1,
sampleRate: 16000,
flushTime: 20,
});
player.volume(gainInput.value);
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
function setGain() {
if (player != null) {
player.volume(gainInput.value);
}
}
</script>
</div>
</body>
</html>

View File

@@ -1,6 +1,5 @@
{
"name": "Bumble Aid Left",
"address": "F1:F2:F3:F4:F5:F6",
"identity_address_type": 1,
"keystore": "JsonKeyStore"
}
}

View File

@@ -1,6 +1,5 @@
{
"name": "Bumble Aid Right",
"address": "F7:F8:F9:FA:FB:FC",
"identity_address_type": 1,
"keystore": "JsonKeyStore"
}
}

View File

@@ -16,104 +16,192 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import sys
import os
import logging
import websockets
from typing import Optional
from bumble import decoder
from bumble import gatt
from bumble import l2cap
from bumble.core import AdvertisingData
from bumble.device import Device, AdvertisingParameters
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.profiles import asha
ws_connection: Optional[websockets.WebSocketServerProtocol] = None
g722_decoder = decoder.G722Decoder()
from bumble.core import UUID
from bumble.gatt import Service, Characteristic, CharacteristicValue
async def ws_server(ws_client: websockets.WebSocketServerProtocol, path: str):
del path
global ws_connection
ws_connection = ws_client
async for message in ws_client:
print(message)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID(
'6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties'
)
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID(
'f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint'
)
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID(
'38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus'
)
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID(
'2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT'
)
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) != 3:
print('Usage: python run_asha_sink.py <device-config> <transport-spec>')
print('example: python run_asha_sink.py device1.json usb:0')
if len(sys.argv) != 4:
print(
'Usage: python run_asha_sink.py <device-config> <transport-spec> '
'<audio-file>'
)
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
def on_audio_packet(packet: bytes) -> None:
global ws_connection
if ws_connection:
offset = 1
while offset < len(packet):
pcm_data = g722_decoder.decode_frame(packet[offset : offset + 80])
offset += 80
asyncio.get_running_loop().create_task(ws_connection.send(pcm_data))
else:
logging.info("No active client")
# Handler for audio control commands
def on_audio_control_point_write(_connection, value):
print('--- AUDIO CONTROL POINT Write:', value.hex())
opcode = value[0]
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(
f'### START: codec={value[1]}, audio_type={audio_type}, '
f'volume={value[3]}, otherstate={value[4]}'
)
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
asha_service = asha.AshaService(
capability=0,
hisyncid=b'\x01\x02\x03\x04\x05\x06\x07\x08',
device=device,
audio_sink=on_audio_packet,
# Respond with a status
asyncio.create_task(
device.notify_subscribers(audio_status_characteristic, force=True)
)
# Handler for volume control
def on_volume_write(_connection, value):
print('--- VOLUME Write:', value[0])
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
print('<<< Voice data received:', data.hex())
audio_out.write(data)
channel.sink = on_data
server = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=8), handler=on_coc
)
print(f'### LE_PSM_OUT = {server.psm}')
# Add the ASHA service to the GATT server
read_only_properties_characteristic = Characteristic(
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01,
0x02,
0x03,
0x04,
0x05,
0x06,
0x07,
0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00,
0x00, # Render Delay
0x00,
0x00, # RFU
0x02,
0x00, # Codec IDs [G.722 at 16 kHz]
]
),
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', server.psm),
)
device.add_service(
Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic,
],
)
)
device.add_service(asha_service)
# Set the advertising data
advertising_data = (
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device.name, 'utf-8'),
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(ASHA_SERVICE)
+ bytes(
[
0x01, # Protocol Version
0x00, # Capability
0x01,
0x02,
0x03,
0x04, # Truncated HiSyncID
]
),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(gatt.GATT_ASHA_SERVICE),
),
]
)
),
]
)
+ asha_service.get_advertising_data()
)
# Go!
await device.power_on()
await device.create_advertising_set(
auto_restart=True,
advertising_data=advertising_data,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await device.start_advertising(auto_restart=True)
await websockets.serve(ws_server, port=8888)
await hci_transport.source.terminated
await hci_transport.source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(
level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper(),
format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -1,107 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble import att
from bumble.profiles.hap import (
HearingAccessService,
HearingAidFeatures,
HearingAidType,
PresetSynchronizationSupport,
IndependentPresets,
DynamicPresets,
WritablePresetsSupport,
PresetRecord,
)
from bumble.transport import open_transport_or_link
server_features = HearingAidFeatures(
HearingAidType.MONAURAL_HEARING_AID,
PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED,
IndependentPresets.IDENTICAL_PRESET_RECORD,
DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE,
WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED,
)
foo_preset = PresetRecord(1, "foo preset")
bar_preset = PresetRecord(50, "bar preset")
foobar_preset = PresetRecord(5, "foobar preset")
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_hap_server.py <config-file> <transport-spec-for-device>')
print('example: run_hap_server.py device1.json pty:hci_pty')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
hap = HearingAccessService(
device, server_features, [foo_preset, bar_preset, foobar_preset]
)
device.add_service(hap)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble HearingAccessService', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(HearingAccessService.UUID),
),
]
)
)
await device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
)
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -21,7 +21,7 @@ import os
import logging
import json
import websockets
import struct
from bumble.colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
@@ -30,7 +30,9 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID,
UUID,
)
from bumble.hci import Address
from bumble.hid import (
Device as HID_Device,
HID_CONTROL_PSM,
@@ -38,17 +40,20 @@ from bumble.hid import (
Message,
)
from bumble.sdp import (
Client as SDP_Client,
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices
@@ -425,7 +430,7 @@ deviceData = DeviceData()
# -----------------------------------------------------------------------------
async def keyboard_device(hid_device: HID_Device):
async def keyboard_device(hid_device):
# Start a Websocket server to receive events from a web page
async def serve(websocket, _path):
@@ -471,9 +476,9 @@ async def keyboard_device(hid_device: HID_Device):
# limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y))
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
">bb", x, y
)
x_cord = x.to_bytes(signed=True)
y_cord = y.to_bytes(signed=True)
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -510,9 +515,7 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
retValue = hid_device.GetSetStatus()
print(
"GET_REPORT report_id: "
@@ -552,7 +555,8 @@ async def main() -> None:
def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes
) -> HID_Device.GetSetStatus:
):
retValue = hid_device.GetSetStatus()
print(
"SET_REPORT report_id: "
+ str(report_id)
@@ -564,33 +568,33 @@ async def main() -> None:
+ str(data)
)
if report_type == Message.ReportType.FEATURE_REPORT:
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData):
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData):
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3:
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
status = HID_Device.GetSetReturn.SUCCESS
retValue.status = hid_device.GetSetReturn.SUCCESS
else:
status = HID_Device.GetSetReturn.SUCCESS
retValue.status = hid_device.GetSetReturn.SUCCESS
return HID_Device.GetSetStatus(status=status)
return retValue
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
return HID_Device.GetSetStatus(
data=bytes([protocol_mode]),
status=hid_device.GetSetReturn.SUCCESS,
)
def on_get_protocol_cb():
retValue = hid_device.GetSetStatus()
retValue.data = protocol_mode.to_bytes()
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
return HID_Device.GetSetStatus(
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')

View File

@@ -35,13 +35,15 @@ from bumble.hci import (
CodingFormat,
OwnAddressType,
)
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import (
CodecSpecificCapabilities,
ContextType,
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
UnicastServerAdvertisingData,
)
from bumble.profiles.mcp import (
@@ -50,7 +52,7 @@ from bumble.profiles.mcp import (
MediaState,
MediaControlPointOpcode,
)
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link
from typing import Optional

View File

@@ -34,8 +34,8 @@ from bumble.hci import (
CodingFormat,
HCI_IsoDataPacket,
)
from bumble.profiles.ascs import AseStateMachine, AudioStreamControlService
from bumble.profiles.bap import (
AseStateMachine,
UnicastServerAdvertisingData,
CodecSpecificConfiguration,
CodecSpecificCapabilities,
@@ -43,10 +43,13 @@ from bumble.profiles.bap import (
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
)
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.transport import open_transport_or_link

View File

@@ -30,7 +30,6 @@ from bumble.hci import (
CodingFormat,
OwnAddressType,
)
from bumble.profiles.ascs import AudioStreamControlService
from bumble.profiles.bap import (
UnicastServerAdvertisingData,
CodecSpecificCapabilities,
@@ -38,8 +37,10 @@ from bumble.profiles.bap import (
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
)
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.vcp import VolumeControlService

View File

@@ -142,7 +142,7 @@ class MainActivity : ComponentActivity() {
::runRfcommClient,
::runRfcommServer,
::runL2capClient,
::runL2capServer,
::runL2capServer
)
}
@@ -166,8 +166,6 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> runRfcommServer()
"l2cap-client" -> runL2capClient()
"l2cap-server" -> runL2capServer()
"scan-start" -> runScan(true)
"stop-start" -> runScan(false)
}
}
}
@@ -192,11 +190,6 @@ class MainActivity : ComponentActivity() {
l2capServer?.run()
}
private fun runScan(startScan: Boolean) {
val scan = bluetoothAdapter?.let { Scan(it) }
scan?.run(startScan)
}
@SuppressLint("MissingPermission")
fun becomeDiscoverable() {
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
@@ -213,7 +206,7 @@ fun MainView(
runRfcommClient: () -> Unit,
runRfcommServer: () -> Unit,
runL2capClient: () -> Unit,
runL2capServer: () -> Unit,
runL2capServer: () -> Unit
) {
BTBenchTheme {
val scrollState = rememberScrollState()

View File

@@ -1,38 +0,0 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.le.ScanCallback
import android.bluetooth.le.ScanResult
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.scan")
class Scan(val bluetoothAdapter: BluetoothAdapter) {
@SuppressLint("MissingPermission")
fun run(startScan: Boolean) {
var bluetoothLeScanner = bluetoothAdapter.bluetoothLeScanner
val scanCallback = object : ScanCallback() {
override fun onScanResult(callbackType: Int, result: ScanResult?) {
super.onScanResult(callbackType, result)
val device: BluetoothDevice? = result?.device
val deviceName = device?.name ?: "Unknown"
val deviceAddress = device?.address ?: "Unknown"
Log.info("Device found: $deviceName ($deviceAddress)")
}
override fun onScanFailed(errorCode: Int) {
// Handle scan failure
Log.warning("Scan failed with error code: $errorCode")
}
}
if (startScan) {
bluetoothLeScanner?.startScan(scanCallback)
} else {
bluetoothLeScanner?.stopScan(scanCallback)
}
}
}

View File

@@ -1,494 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import pytest
import pytest_asyncio
from bumble import device
from bumble.att import ATT_Error
from bumble.profiles.aics import (
Mute,
AICSService,
AudioInputState,
AICSServiceProxy,
GainMode,
AudioInputStatus,
AudioInputControlPointOpCode,
ErrorCode,
)
from bumble.profiles.vcp import VolumeControlService, VolumeControlServiceProxy
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
aics_service = AICSService()
vcp_service = VolumeControlService(
volume_setting=32, muted=1, volume_flags=1, included_services=[aics_service]
)
@pytest_asyncio.fixture
async def aics_client():
devices = TwoDevices()
devices[0].add_service(vcp_service)
await devices.setup_connection()
assert devices.connections[0]
assert devices.connections[1]
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
vcp_client = await peer.discover_service_and_create_proxy(VolumeControlServiceProxy)
assert vcp_client
included_services = await peer.discover_included_services(vcp_client.service_proxy)
assert included_services
aics_service_discovered = included_services[0]
await peer.discover_characteristics(service=aics_service_discovered)
aics_client = AICSServiceProxy(aics_service_discovered)
yield aics_client
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(aics_client: AICSServiceProxy):
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL,
change_counter=0,
)
assert await aics_client.gain_settings_properties.read_value() == (1, 0, 255)
assert await aics_client.audio_input_status.read_value() == (
AudioInputStatus.ACTIVE
)
@pytest.mark.asyncio
async def test_wrong_opcode_raise_error(aics_client: AICSServiceProxy):
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
0xFF,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.OPCODE_NOT_SUPPORTED
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_automatic_only(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
# Unchanged
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.AUTOMATIC_ONLY,
change_counter=0,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_automatic(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
# Unchanged
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=0,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.AUTOMATIC,
change_counter=0,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_MANUAL(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=gain_settings,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL,
change_counter=change_counter,
)
@pytest.mark.asyncio
async def test_set_gain_setting_when_gain_mode_MANUAL_ONLY(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
change_counter = 0
gain_settings = 120
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_GAIN_SETTING,
change_counter,
gain_settings,
]
)
)
assert await aics_client.audio_input_state.read_value() == AudioInputState(
gain_settings=gain_settings,
mute=Mute.NOT_MUTED,
gain_mode=GainMode.MANUAL_ONLY,
change_counter=change_counter,
)
@pytest.mark.asyncio
async def test_unmute_when_muted(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.MUTED
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.UNMUTE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.NOT_MUTED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_unmute_when_mute_disabled(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.DISABLED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.UNMUTE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.MUTE_DISABLED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.DISABLED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_mute_when_not_muted(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.NOT_MUTED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.MUTE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.MUTED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_mute_when_mute_disabled(aics_client: AICSServiceProxy):
aics_service.audio_input_state.mute = Mute.DISABLED
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.MUTE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.MUTE_DISABLED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.mute == Mute.DISABLED
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_automatic(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_already_manual(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
)
)
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_manual_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_manual_gain_mode_when_automatic_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_manual(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
)
)
change_counter += 1
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_already_automatic(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC
aics_service.audio_input_state.change_counter = 0
change_counter = 0
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
)
)
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_manual_only(aics_client: AICSServiceProxy):
aics_service.audio_input_state.gain_mode = GainMode.MANUAL_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.MANUAL_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_set_automatic_gain_mode_when_automatic_only(
aics_client: AICSServiceProxy,
):
aics_service.audio_input_state.gain_mode = GainMode.AUTOMATIC_ONLY
aics_service.audio_input_state.change_counter = 0
change_counter = 0
with pytest.raises(ATT_Error) as e:
await aics_client.audio_input_control_point.write_value(
bytes(
[
AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE,
change_counter,
]
),
with_response=True,
)
assert e.value.error_code == ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED
# No change expected
state: AudioInputState = await aics_client.audio_input_state.read_value()
assert state.gain_mode == GainMode.AUTOMATIC_ONLY
assert state.change_counter == change_counter
@pytest.mark.asyncio
async def test_audio_input_description_initial_value(aics_client: AICSServiceProxy):
description = await aics_client.audio_input_description.read_value()
assert description.decode('utf-8') == "Bluetooth"
@pytest.mark.asyncio
async def test_audio_input_description_write_and_read(aics_client: AICSServiceProxy):
new_description = "Line Input".encode('utf-8')
await aics_client.audio_input_description.write_value(new_description)
description = await aics_client.audio_input_description.read_value()
assert description == new_description

View File

@@ -1,163 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
import struct
from unittest import mock
from bumble import device as bumble_device
from bumble.profiles import asha
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
HI_SYNC_ID = b'\x00\x01\x02\x03\x04\x05\x06\x07'
TIMEOUT = 0.1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_only_properties():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
protocol_version=0x01,
capability=0x02,
feature_map=0x03,
render_delay_milliseconds=0x04,
supported_codecs=0x05,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
read_only_properties = (
await asha_client.read_only_properties_characteristic.read_value()
)
(
protocol_version,
capabilities,
hi_sync_id,
feature_map,
render_delay_milliseconds,
_,
supported_codecs,
) = struct.unpack("<BB8sBHHH", read_only_properties)
assert protocol_version == 0x01
assert capabilities == 0x02
assert hi_sync_id == HI_SYNC_ID
assert feature_map == 0x03
assert render_delay_milliseconds == 0x04
assert supported_codecs == 0x05
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_psm():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
psm = (await asha_client.psm_characteristic.read_value())[0]
assert psm == asha_service.psm
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_audio_control_point_start():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
status_notifications = asyncio.Queue()
await asha_client.audio_status_point_characteristic.subscribe(
status_notifications.put_nowait
)
start_cb = mock.MagicMock()
asha_service.on('started', start_cb)
await asha_client.audio_control_point_characteristic.write_value(
bytes(
[asha.OpCode.START, asha.Codec.G_722_16KHZ, asha.AudioType.MEDIA, 0, 1]
)
)
status = (await asyncio.wait_for(status_notifications.get(), TIMEOUT))[0]
assert status == asha.AudioStatus.OK
start_cb.assert_called_once()
assert asha_service.active_codec == asha.Codec.G_722_16KHZ
assert asha_service.volume == 0
assert asha_service.other_state == 1
assert asha_service.audio_type == asha.AudioType.MEDIA
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_audio_control_point_stop():
devices = TwoDevices()
await devices.setup_connection()
asha_service = asha.AshaService(
hisyncid=HI_SYNC_ID,
device=devices[0],
capability=0,
)
devices[0].add_service(asha_service)
async with bumble_device.Peer(devices.connections[1]) as peer:
asha_client = peer.create_service_proxy(asha.AshaServiceProxy)
assert asha_client
status_notifications = asyncio.Queue()
await asha_client.audio_status_point_characteristic.subscribe(
status_notifications.put_nowait
)
stop_cb = mock.MagicMock()
asha_service.on('stopped', stop_cb)
await asha_client.audio_control_point_characteristic.write_value(
bytes([asha.OpCode.STOP])
)
status = (await asyncio.wait_for(status_notifications.get(), TIMEOUT))[0]
assert status == asha.AudioStatus.OK
stop_cb.assert_called_once()
assert asha_service.active_codec is None
assert asha_service.volume is None
assert asha_service.other_state is None
assert asha_service.audio_type is None

View File

@@ -23,9 +23,8 @@ import logging
from bumble import device
from bumble.hci import CodecID, CodingFormat
from bumble.profiles.ascs import (
AudioStreamControlService,
AudioStreamControlServiceProxy,
from bumble.profiles.bap import (
AudioLocation,
AseStateMachine,
ASE_Operation,
ASE_Config_Codec,
@@ -36,9 +35,6 @@ from bumble.profiles.ascs import (
ASE_Receiver_Stop_Ready,
ASE_Release,
ASE_Update_Metadata,
)
from bumble.profiles.bap import (
AudioLocation,
SupportedFrameDuration,
SupportedSamplingFrequency,
SamplingFrequency,
@@ -46,9 +42,9 @@ from bumble.profiles.bap import (
CodecSpecificCapabilities,
CodecSpecificConfiguration,
ContextType,
)
from bumble.profiles.pacs import (
PacRecord,
AudioStreamControlService,
AudioStreamControlServiceProxy,
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)

View File

@@ -1,146 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from bumble import hci
from bumble.profiles import bass
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def basic_operation_check(operation: bass.ControlPointOperation) -> None:
serialized = bytes(operation)
parsed = bass.ControlPointOperation.from_bytes(serialized)
assert bytes(parsed) == serialized
# -----------------------------------------------------------------------------
def test_operations() -> None:
op1 = bass.RemoteScanStoppedOperation()
basic_operation_check(op1)
op2 = bass.RemoteScanStartedOperation()
basic_operation_check(op2)
op3 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(),
)
basic_operation_check(op3)
op4 = bass.AddSourceOperation(
hci.Address("AA:BB:CC:DD:EE:FF"),
34,
123456,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
456,
(
bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')),
bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')),
),
)
basic_operation_check(op4)
op5 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(),
)
basic_operation_check(op5)
op6 = bass.ModifySourceOperation(
12,
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
567,
(
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
),
)
basic_operation_check(op6)
op7 = bass.SetBroadcastCodeOperation(
7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf')
)
basic_operation_check(op7)
op8 = bass.RemoveSourceOperation(7)
basic_operation_check(op8)
# -----------------------------------------------------------------------------
def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None:
serialized = bytes(brs)
parsed = bass.BroadcastReceiveState.from_bytes(serialized)
assert parsed is not None
assert bytes(parsed) == serialized
def test_broadcast_receive_state() -> None:
subgroups = [
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
]
brs1 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.DECRYPTING,
b'',
subgroups,
)
basic_broadcast_receive_state_check(brs1)
brs2 = bass.BroadcastReceiveState(
12,
hci.Address("AA:BB:CC:DD:EE:FF"),
123,
123456,
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
bass.BroadcastReceiveState.BigEncryption.BAD_CODE,
bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'),
subgroups,
)
basic_broadcast_receive_state_check(brs2)
# -----------------------------------------------------------------------------
async def run():
test_operations()
test_broadcast_receive_state()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -536,16 +536,6 @@ async def test_cis_setup_failure():
await asyncio.wait_for(cis_create_task, _TIMEOUT)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
devices = TwoDevices()
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
await devices[0].power_on()
assert devices[0].static_address != Address.ANY_RANDOM
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))

View File

@@ -47,10 +47,8 @@ from bumble.att import (
ATT_EXCHANGE_MTU_REQUEST,
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
ATT_Error,
ATT_Error_Response,
ATT_Read_By_Group_Type_Request,
ErrorCode,
)
from .test_utils import async_barrier
@@ -1249,32 +1247,6 @@ async def test_get_characteristics_by_uuid():
assert len(s) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_return_error():
[client, server] = LinkedDevices().devices[:2]
on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED))
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=on_write),
)
service = Service('ABCD', [characteristic])
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
async with Peer(connection) as peer:
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0]
with pytest.raises(ATT_Error) as e:
await c.write_value(b'', with_response=True)
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

View File

@@ -1,227 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import pytest
import functools
import pytest_asyncio
import logging
import sys
from bumble import att, device
from bumble.profiles import hap
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
foo_preset = hap.PresetRecord(1, "foo preset")
bar_preset = hap.PresetRecord(50, "bar preset")
foobar_preset = hap.PresetRecord(5, "foobar preset")
unavailable_preset = hap.PresetRecord(
78,
"foobar preset",
hap.PresetRecord.Property(
hap.PresetRecord.Property.Writable.CANNOT_BE_WRITTEN,
hap.PresetRecord.Property.IsAvailable.IS_UNAVAILABLE,
),
)
server_features = hap.HearingAidFeatures(
hap.HearingAidType.MONAURAL_HEARING_AID,
hap.PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED,
hap.IndependentPresets.IDENTICAL_PRESET_RECORD,
hap.DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE,
hap.WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED,
)
TIMEOUT = 0.1
async def assert_queue_is_empty(queue: asyncio.Queue):
assert queue.empty()
# Check that nothing is being added during TIMEOUT secondes
if sys.version_info >= (3, 11):
with pytest.raises(TimeoutError):
await asyncio.wait_for(queue.get(), TIMEOUT)
else:
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get(), TIMEOUT)
# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def hap_client():
devices = TwoDevices()
devices[0].add_service(
hap.HearingAccessService(
devices[0],
server_features,
[foo_preset, bar_preset, foobar_preset, unavailable_preset],
)
)
await devices.setup_connection()
# TODO negotiate MTU > 49 to not truncate preset names
# Mock encryption.
devices.connections[0].encryption = 1 # type: ignore
devices.connections[1].encryption = 1 # type: ignore
peer = device.Peer(devices.connections[1]) # type: ignore
hap_client = await peer.discover_service_and_create_proxy(
hap.HearingAccessServiceProxy
)
assert hap_client
await hap_client.setup_subscription()
yield hap_client
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(hap_client: hap.HearingAccessServiceProxy):
assert (
hap.HearingAidFeatures_from_bytes(await hap_client.server_features.read_value())
== server_features
)
assert (await hap_client.active_preset_index.read_value()) == (foo_preset.index)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_all_presets(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 1, 0xFF])
)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(foo_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(foobar_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
) + bytes(bar_preset)
assert (await hap_client.preset_control_point_indications.get()) == bytes(
[hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 1]
) + bytes(unavailable_preset)
await assert_queue_is_empty(hap_client.preset_control_point_indications)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_partial_presets(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 3, 2])
)
assert (await hap_client.preset_control_point_indications.get())[2:] == bytes(
foobar_preset
)
assert (await hap_client.preset_control_point_indications.get())[2:] == bytes(
bar_preset
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_valid(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes(
[hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET, bar_preset.index]
)
)
assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_invalid(hap_client: hap.HearingAccessServiceProxy):
with pytest.raises(att.ATT_Error) as e:
await hap_client.hearing_aid_preset_control_point.write_value(
bytes(
[
hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET,
unavailable_preset.index,
]
),
with_response=True,
)
assert e.value.error_code == hap.ErrorCode.PRESET_OPERATION_NOT_POSSIBLE
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset(hap_client: hap.HearingAccessServiceProxy):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])
)
assert (
await hap_client.active_preset_index_notification.get()
) == foobar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (foobar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset_will_loop_to_first(
hap_client: hap.HearingAccessServiceProxy,
):
async def go_next(new_preset: hap.PresetRecord):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])
)
assert (
await hap_client.active_preset_index_notification.get()
) == new_preset.index
assert (await hap_client.active_preset_index.read_value()) == (new_preset.index)
await go_next(foobar_preset)
await go_next(bar_preset)
await go_next(foo_preset)
# Note that there is a invalid preset in the preset record of the server
await assert_queue_is_empty(hap_client.active_preset_index_notification)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_previous_preset_will_loop_to_last(
hap_client: hap.HearingAccessServiceProxy,
):
await hap_client.hearing_aid_preset_control_point.write_value(
bytes([hap.HearingAidPresetControlPointOpcode.SET_PREVIOUS_PRESET])
)
assert (await hap_client.active_preset_index_notification.get()) == bar_preset.index
assert (await hap_client.active_preset_index.read_value()) == (bar_preset.index)
await assert_queue_is_empty(hap_client.active_preset_index_notification)

View File

@@ -60,8 +60,6 @@ from bumble.hci import (
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
HCI_PIN_Code_Request_Reply_Command,
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Supported_Features_Command,
HCI_Read_Local_Version_Information_Command,
@@ -478,51 +476,6 @@ def test_HCI_LE_Setup_ISO_Data_Path_Command():
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
returned_parameters = (
HCI_Read_Local_Supported_Codecs_Command.parse_return_parameters(
bytes([HCI_SUCCESS, 3, CodecID.A_LOG, CodecID.CVSD, CodecID.LINEAR_PCM, 0])
)
)
assert returned_parameters.standard_codec_ids == [
CodecID.A_LOG,
CodecID.CVSD,
CodecID.LINEAR_PCM,
]
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = (
HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
HCI_SUCCESS,
3,
CodecID.A_LOG,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
CodecID.CVSD,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
CodecID.LINEAR_PCM,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
0,
]
)
)
)
assert returned_parameters.standard_codec_ids == [
CodecID.A_LOG,
CodecID.CVSD,
CodecID.LINEAR_PCM,
]
assert returned_parameters.standard_codec_transports == [
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
]
# -----------------------------------------------------------------------------
def test_address():
a = Address('C4:F2:17:1A:1D:BB')

View File

@@ -27,6 +27,7 @@ def test_import():
core,
crypto,
device,
gap,
hci,
hfp,
host,
@@ -40,22 +41,6 @@ def test_import():
utils,
)
from bumble.profiles import (
ascs,
bap,
bass,
battery_service,
cap,
csip,
device_information_service,
gap,
heart_rate_service,
le_audio,
pacs,
pbp,
vcp,
)
assert att
assert bridge
assert company_ids
@@ -63,6 +48,7 @@ def test_import():
assert core
assert crypto
assert device
assert gap
assert hci
assert hfp
assert host
@@ -75,20 +61,6 @@ def test_import():
assert transport
assert utils
assert ascs
assert bap
assert bass
assert battery_service
assert cap
assert csip
assert device_information_service
assert gap
assert heart_rate_service
assert le_audio
assert pacs
assert pbp
assert vcp
# -----------------------------------------------------------------------------
def test_app_imports():

View File

@@ -17,17 +17,13 @@
# -----------------------------------------------------------------------------
import pytest
from unittest import mock
from bumble import smp
from bumble import pairing
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address
from bumble.core import AdvertisingData
from bumble.device import Device
from typing import Optional
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -255,57 +251,6 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'identity_address_type, public_address, random_address, expected_identity_address',
[
(
None,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
None,
Address.ANY,
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.PUBLIC,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.RANDOM,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
identity_address_type: Optional[pairing.PairingConfig.AddressType],
public_address: Address,
random_address: Address,
expected_identity_address: Address,
):
device = Device()
device.public_address = public_address
device.static_address = random_address
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
with mock.patch.object(session, 'send_command') as mock_method:
session.send_identity_address_command()
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ecc()

View File

@@ -39,9 +39,6 @@ async def vcp_client():
await devices.setup_connection()
assert devices.connections[0]
assert devices.connections[1]
# Mock encryption.
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1