forked from auracaster/bumble_mirror
Compare commits
52 Commits
packageFil
...
gbg/usb-th
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a311c3f723 | ||
|
|
b2bb82a432 | ||
|
|
597560ff80 | ||
|
|
db383bb3e6 | ||
|
|
ccc5bbdad4 | ||
|
|
11c8229017 | ||
|
|
2248f9ae5e | ||
|
|
03c79aacb2 | ||
|
|
0c31713a8e | ||
|
|
9dd814f32e | ||
|
|
ab6e595bcb | ||
|
|
f08fac8c8a | ||
|
|
a699520188 | ||
|
|
f66633459e | ||
|
|
f3b776c343 | ||
|
|
de7b99ce34 | ||
|
|
c0b17d9aff | ||
|
|
3c12be59c5 | ||
|
|
c6b3deb8df | ||
|
|
a0b5606047 | ||
|
|
3824e38485 | ||
|
|
4433184048 | ||
|
|
312fc8db36 | ||
|
|
615691ec81 | ||
|
|
ae8b83f294 | ||
|
|
4a8e21f4db | ||
|
|
3462e7c437 | ||
|
|
0f2e5239ad | ||
|
|
ee48cdc63f | ||
|
|
1c278bec93 | ||
|
|
6a51166af7 | ||
|
|
85d79fa914 | ||
|
|
142bdce94a | ||
|
|
881a5a64b5 | ||
|
|
5aae44b610 | ||
|
|
e3ea167827 | ||
|
|
eec145e095 | ||
|
|
87fa02d6e5 | ||
|
|
ad94c1e1f3 | ||
|
|
546a0bce8d | ||
|
|
cb7ca44a1c | ||
|
|
4081b93407 | ||
|
|
26203ebaad | ||
|
|
3389e3e1ed | ||
|
|
7e1f01c01e | ||
|
|
613e15548a | ||
|
|
e09c91df8e | ||
|
|
df206667b6 | ||
|
|
0f19dd5263 | ||
|
|
b98e4937f3 | ||
|
|
27791cf218 | ||
|
|
f8a2d4f0e0 |
30
.devcontainer/devcontainer.json
Normal file
30
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||||
|
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
||||||
|
{
|
||||||
|
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||||
|
"image": "mcr.microsoft.com/devcontainers/universal:2",
|
||||||
|
|
||||||
|
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||||
|
// "features": {},
|
||||||
|
|
||||||
|
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||||
|
// "forwardPorts": [],
|
||||||
|
|
||||||
|
// Use 'postCreateCommand' to run commands after the container is created.
|
||||||
|
"postCreateCommand":
|
||||||
|
"python -m pip install '.[build,test,development,documentation]'",
|
||||||
|
|
||||||
|
// Configure tool-specific properties.
|
||||||
|
"customizations": {
|
||||||
|
// Configure properties specific to VS Code.
|
||||||
|
"vscode": {
|
||||||
|
// Add the IDs of extensions you want installed when the container is created.
|
||||||
|
"extensions": [
|
||||||
|
"ms-python.python"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||||
|
// "remoteUser": "root"
|
||||||
|
}
|
||||||
417
apps/auracast.py
417
apps/auracast.py
@@ -17,10 +17,11 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import cast, Dict, Optional, Tuple
|
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import pyee
|
import pyee
|
||||||
@@ -32,6 +33,7 @@ import bumble.device
|
|||||||
import bumble.gatt
|
import bumble.gatt
|
||||||
import bumble.hci
|
import bumble.hci
|
||||||
import bumble.profiles.bap
|
import bumble.profiles.bap
|
||||||
|
import bumble.profiles.bass
|
||||||
import bumble.profiles.pbp
|
import bumble.profiles.pbp
|
||||||
import bumble.transport
|
import bumble.transport
|
||||||
import bumble.utils
|
import bumble.utils
|
||||||
@@ -46,14 +48,16 @@ logger = logging.getLogger(__name__)
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
|
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
|
||||||
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
|
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
|
||||||
|
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
|
||||||
|
AURACAST_DEFAULT_ATT_MTU = 256
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Discover Broadcasts
|
# Scan For Broadcasts
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class BroadcastDiscoverer:
|
class BroadcastScanner(pyee.EventEmitter):
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Broadcast(pyee.EventEmitter):
|
class Broadcast(pyee.EventEmitter):
|
||||||
name: str
|
name: str
|
||||||
@@ -79,22 +83,6 @@ class BroadcastDiscoverer:
|
|||||||
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
|
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
|
||||||
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
|
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
|
||||||
|
|
||||||
self.establishment_timeout_task = asyncio.create_task(
|
|
||||||
self.wait_for_establishment()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def wait_for_establishment(self) -> None:
|
|
||||||
await asyncio.sleep(5.0)
|
|
||||||
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
|
|
||||||
print(
|
|
||||||
color(
|
|
||||||
'!!! Periodic advertisement sync not established in time, '
|
|
||||||
'canceling',
|
|
||||||
'red',
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await self.sync.terminate()
|
|
||||||
|
|
||||||
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
||||||
self.rssi = advertisement.rssi
|
self.rssi = advertisement.rssi
|
||||||
for service_data in advertisement.data.get_all(
|
for service_data in advertisement.data.get_all(
|
||||||
@@ -139,6 +127,8 @@ class BroadcastDiscoverer:
|
|||||||
data,
|
data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.emit('update')
|
||||||
|
|
||||||
def print(self) -> None:
|
def print(self) -> None:
|
||||||
print(
|
print(
|
||||||
color('Broadcast:', 'yellow'),
|
color('Broadcast:', 'yellow'),
|
||||||
@@ -227,13 +217,12 @@ class BroadcastDiscoverer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_sync_establishment(self) -> None:
|
def on_sync_establishment(self) -> None:
|
||||||
self.establishment_timeout_task.cancel()
|
self.emit('sync_establishment')
|
||||||
self.emit('change')
|
|
||||||
|
|
||||||
def on_sync_loss(self) -> None:
|
def on_sync_loss(self) -> None:
|
||||||
self.basic_audio_announcement = None
|
self.basic_audio_announcement = None
|
||||||
self.biginfo = None
|
self.biginfo = None
|
||||||
self.emit('change')
|
self.emit('sync_loss')
|
||||||
|
|
||||||
def on_periodic_advertisement(
|
def on_periodic_advertisement(
|
||||||
self, advertisement: bumble.device.PeriodicAdvertisement
|
self, advertisement: bumble.device.PeriodicAdvertisement
|
||||||
@@ -268,37 +257,21 @@ class BroadcastDiscoverer:
|
|||||||
filter_duplicates: bool,
|
filter_duplicates: bool,
|
||||||
sync_timeout: float,
|
sync_timeout: float,
|
||||||
):
|
):
|
||||||
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.filter_duplicates = filter_duplicates
|
self.filter_duplicates = filter_duplicates
|
||||||
self.sync_timeout = sync_timeout
|
self.sync_timeout = sync_timeout
|
||||||
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
|
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
|
||||||
self.status_message = ''
|
|
||||||
device.on('advertisement', self.on_advertisement)
|
device.on('advertisement', self.on_advertisement)
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def start(self) -> None:
|
||||||
self.status_message = color('Scanning...', 'green')
|
|
||||||
await self.device.start_scanning(
|
await self.device.start_scanning(
|
||||||
active=False,
|
active=False,
|
||||||
filter_duplicates=False,
|
filter_duplicates=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def refresh(self) -> None:
|
async def stop(self) -> None:
|
||||||
# Clear the screen from the top
|
await self.device.stop_scanning()
|
||||||
print('\033[H')
|
|
||||||
print('\033[0J')
|
|
||||||
print('\033[H')
|
|
||||||
|
|
||||||
# Print the status message
|
|
||||||
print(self.status_message)
|
|
||||||
print("==========================================")
|
|
||||||
|
|
||||||
# Print all broadcasts
|
|
||||||
for broadcast in self.broadcasts.values():
|
|
||||||
broadcast.print()
|
|
||||||
print('------------------------------------------')
|
|
||||||
|
|
||||||
# Clear the screen to the bottom
|
|
||||||
print('\033[0J')
|
|
||||||
|
|
||||||
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
||||||
if (
|
if (
|
||||||
@@ -311,7 +284,6 @@ class BroadcastDiscoverer:
|
|||||||
|
|
||||||
if broadcast := self.broadcasts.get(advertisement.address):
|
if broadcast := self.broadcasts.get(advertisement.address):
|
||||||
broadcast.update(advertisement)
|
broadcast.update(advertisement)
|
||||||
self.refresh()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
bumble.utils.AsyncRunner.spawn(
|
bumble.utils.AsyncRunner.spawn(
|
||||||
@@ -331,41 +303,318 @@ class BroadcastDiscoverer:
|
|||||||
name,
|
name,
|
||||||
periodic_advertising_sync,
|
periodic_advertising_sync,
|
||||||
)
|
)
|
||||||
broadcast.on('change', self.refresh)
|
|
||||||
broadcast.update(advertisement)
|
broadcast.update(advertisement)
|
||||||
self.broadcasts[advertisement.address] = broadcast
|
self.broadcasts[advertisement.address] = broadcast
|
||||||
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
||||||
self.status_message = color(
|
self.emit('new_broadcast', broadcast)
|
||||||
f'+Found {len(self.broadcasts)} broadcasts', 'green'
|
|
||||||
)
|
|
||||||
self.refresh()
|
|
||||||
|
|
||||||
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
|
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
|
||||||
del self.broadcasts[broadcast.sync.advertiser_address]
|
del self.broadcasts[broadcast.sync.advertiser_address]
|
||||||
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
|
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
|
||||||
|
self.emit('broadcast_loss', broadcast)
|
||||||
|
|
||||||
|
|
||||||
|
class PrintingBroadcastScanner:
|
||||||
|
def __init__(
|
||||||
|
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
|
||||||
|
) -> None:
|
||||||
|
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||||
|
self.scanner.on('new_broadcast', self.on_new_broadcast)
|
||||||
|
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
|
||||||
|
self.scanner.on('update', self.refresh)
|
||||||
|
self.status_message = ''
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self.status_message = color('Scanning...', 'green')
|
||||||
|
await self.scanner.start()
|
||||||
|
|
||||||
|
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
|
||||||
self.status_message = color(
|
self.status_message = color(
|
||||||
f'-Found {len(self.broadcasts)} broadcasts', 'green'
|
f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
|
||||||
|
)
|
||||||
|
broadcast.on('change', self.refresh)
|
||||||
|
broadcast.on('update', self.refresh)
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
|
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
|
||||||
|
self.status_message = color(
|
||||||
|
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
|
||||||
)
|
)
|
||||||
self.refresh()
|
self.refresh()
|
||||||
|
|
||||||
|
def refresh(self) -> None:
|
||||||
|
# Clear the screen from the top
|
||||||
|
print('\033[H')
|
||||||
|
print('\033[0J')
|
||||||
|
print('\033[H')
|
||||||
|
|
||||||
async def run_discover_broadcasts(
|
# Print the status message
|
||||||
filter_duplicates: bool, sync_timeout: float, transport: str
|
print(self.status_message)
|
||||||
) -> None:
|
print("==========================================")
|
||||||
|
|
||||||
|
# Print all broadcasts
|
||||||
|
for broadcast in self.scanner.broadcasts.values():
|
||||||
|
broadcast.print()
|
||||||
|
print('------------------------------------------')
|
||||||
|
|
||||||
|
# Clear the screen to the bottom
|
||||||
|
print('\033[0J')
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
|
||||||
async with await bumble.transport.open_transport(transport) as (
|
async with await bumble.transport.open_transport(transport) as (
|
||||||
hci_source,
|
hci_source,
|
||||||
hci_sink,
|
hci_sink,
|
||||||
):
|
):
|
||||||
device = bumble.device.Device.with_hci(
|
device_config = bumble.device.DeviceConfiguration(
|
||||||
AURACAST_DEFAULT_DEVICE_NAME,
|
name=AURACAST_DEFAULT_DEVICE_NAME,
|
||||||
AURACAST_DEFAULT_DEVICE_ADDRESS,
|
address=AURACAST_DEFAULT_DEVICE_ADDRESS,
|
||||||
|
keystore='JsonKeyStore',
|
||||||
|
)
|
||||||
|
|
||||||
|
device = bumble.device.Device.from_config_with_hci(
|
||||||
|
device_config,
|
||||||
hci_source,
|
hci_source,
|
||||||
hci_sink,
|
hci_sink,
|
||||||
)
|
)
|
||||||
await device.power_on()
|
await device.power_on()
|
||||||
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
|
|
||||||
await discoverer.run()
|
yield device
|
||||||
await hci_source.terminated
|
|
||||||
|
|
||||||
|
async def find_broadcast_by_name(
|
||||||
|
device: bumble.device.Device, name: Optional[str]
|
||||||
|
) -> BroadcastScanner.Broadcast:
|
||||||
|
result = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
|
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||||
|
if broadcast.basic_audio_announcement and not result.done():
|
||||||
|
print(color('Broadcast basic audio announcement received', 'green'))
|
||||||
|
result.set_result(broadcast)
|
||||||
|
|
||||||
|
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||||
|
if name is None or broadcast.name == name:
|
||||||
|
print(color('Broadcast found:', 'green'), broadcast.name)
|
||||||
|
broadcast.on('change', lambda: on_broadcast_change(broadcast))
|
||||||
|
return
|
||||||
|
|
||||||
|
print(color(f'Skipping broadcast {broadcast.name}'))
|
||||||
|
|
||||||
|
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
|
||||||
|
scanner.on('new_broadcast', on_new_broadcast)
|
||||||
|
await scanner.start()
|
||||||
|
|
||||||
|
broadcast = await result
|
||||||
|
await scanner.stop()
|
||||||
|
|
||||||
|
return broadcast
|
||||||
|
|
||||||
|
|
||||||
|
async def run_scan(
|
||||||
|
filter_duplicates: bool, sync_timeout: float, transport: str
|
||||||
|
) -> None:
|
||||||
|
async with create_device(transport) as device:
|
||||||
|
if not device.supports_le_periodic_advertising:
|
||||||
|
print(color('Periodic advertising not supported', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||||
|
await scanner.start()
|
||||||
|
await asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_assist(
|
||||||
|
broadcast_name: Optional[str],
|
||||||
|
source_id: Optional[int],
|
||||||
|
command: str,
|
||||||
|
transport: str,
|
||||||
|
address: str,
|
||||||
|
) -> None:
|
||||||
|
async with create_device(transport) as device:
|
||||||
|
if not device.supports_le_periodic_advertising:
|
||||||
|
print(color('Periodic advertising not supported', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Connect to the server
|
||||||
|
print(f'=== Connecting to {address}...')
|
||||||
|
connection = await device.connect(address)
|
||||||
|
peer = bumble.device.Peer(connection)
|
||||||
|
print(f'=== Connected to {peer}')
|
||||||
|
|
||||||
|
print("+++ Encrypting connection...")
|
||||||
|
await peer.connection.encrypt()
|
||||||
|
print("+++ Connection encrypted")
|
||||||
|
|
||||||
|
# Request a larger MTU
|
||||||
|
mtu = AURACAST_DEFAULT_ATT_MTU
|
||||||
|
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
|
||||||
|
await peer.request_mtu(mtu)
|
||||||
|
|
||||||
|
# Get the BASS service
|
||||||
|
bass = await peer.discover_service_and_create_proxy(
|
||||||
|
bumble.profiles.bass.BroadcastAudioScanServiceProxy
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the service was found
|
||||||
|
if not bass:
|
||||||
|
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Subscribe to and read the broadcast receive state characteristics
|
||||||
|
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
|
||||||
|
try:
|
||||||
|
await broadcast_receive_state.subscribe(
|
||||||
|
lambda value, i=i: print(
|
||||||
|
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except bumble.core.ProtocolError as error:
|
||||||
|
print(
|
||||||
|
color(
|
||||||
|
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
|
||||||
|
'red',
|
||||||
|
),
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
value = await broadcast_receive_state.read_value()
|
||||||
|
print(
|
||||||
|
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if command == 'monitor-state':
|
||||||
|
await peer.sustain()
|
||||||
|
return
|
||||||
|
|
||||||
|
if command == 'add-source':
|
||||||
|
# Find the requested broadcast
|
||||||
|
await bass.remote_scan_started()
|
||||||
|
if broadcast_name:
|
||||||
|
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||||
|
else:
|
||||||
|
print(color('Scanning for any broadcast', 'cyan'))
|
||||||
|
broadcast = await find_broadcast_by_name(device, broadcast_name)
|
||||||
|
|
||||||
|
if broadcast.broadcast_audio_announcement is None:
|
||||||
|
print(color('No broadcast audio announcement found', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
broadcast.basic_audio_announcement is None
|
||||||
|
or not broadcast.basic_audio_announcement.subgroups
|
||||||
|
):
|
||||||
|
print(color('No subgroups found', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add the source
|
||||||
|
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
|
||||||
|
await bass.add_source(
|
||||||
|
broadcast.sync.advertiser_address,
|
||||||
|
broadcast.sync.sid,
|
||||||
|
broadcast.broadcast_audio_announcement.broadcast_id,
|
||||||
|
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||||
|
0xFFFF,
|
||||||
|
[
|
||||||
|
bumble.profiles.bass.SubgroupInfo(
|
||||||
|
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||||
|
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initiate a PA Sync Transfer
|
||||||
|
await broadcast.sync.transfer(peer.connection)
|
||||||
|
|
||||||
|
# Notify the sink that we're done scanning.
|
||||||
|
await bass.remote_scan_stopped()
|
||||||
|
|
||||||
|
await peer.sustain()
|
||||||
|
return
|
||||||
|
|
||||||
|
if command == 'modify-source':
|
||||||
|
if source_id is None:
|
||||||
|
print(color('!!! modify-source requires --source-id'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find the requested broadcast
|
||||||
|
await bass.remote_scan_started()
|
||||||
|
if broadcast_name:
|
||||||
|
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||||
|
else:
|
||||||
|
print(color('Scanning for any broadcast', 'cyan'))
|
||||||
|
broadcast = await find_broadcast_by_name(device, broadcast_name)
|
||||||
|
|
||||||
|
if broadcast.broadcast_audio_announcement is None:
|
||||||
|
print(color('No broadcast audio announcement found', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
broadcast.basic_audio_announcement is None
|
||||||
|
or not broadcast.basic_audio_announcement.subgroups
|
||||||
|
):
|
||||||
|
print(color('No subgroups found', 'red'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Modify the source
|
||||||
|
print(
|
||||||
|
color('Modifying source:', 'blue'),
|
||||||
|
source_id,
|
||||||
|
)
|
||||||
|
await bass.modify_source(
|
||||||
|
source_id,
|
||||||
|
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||||
|
0xFFFF,
|
||||||
|
[
|
||||||
|
bumble.profiles.bass.SubgroupInfo(
|
||||||
|
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||||
|
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await peer.sustain()
|
||||||
|
return
|
||||||
|
|
||||||
|
if command == 'remove-source':
|
||||||
|
if source_id is None:
|
||||||
|
print(color('!!! remove-source requires --source-id'))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove the source
|
||||||
|
print(color('Removing source:', 'blue'), source_id)
|
||||||
|
await bass.remove_source(source_id)
|
||||||
|
await peer.sustain()
|
||||||
|
return
|
||||||
|
|
||||||
|
print(color(f'!!! invalid command {command}'))
|
||||||
|
|
||||||
|
|
||||||
|
async def run_pair(transport: str, address: str) -> None:
|
||||||
|
async with create_device(transport) as device:
|
||||||
|
|
||||||
|
# Connect to the server
|
||||||
|
print(f'=== Connecting to {address}...')
|
||||||
|
async with device.connect_as_gatt(address) as peer:
|
||||||
|
print(f'=== Connected to {peer}')
|
||||||
|
|
||||||
|
print("+++ Initiating pairing...")
|
||||||
|
await peer.connection.pair()
|
||||||
|
print("+++ Paired")
|
||||||
|
|
||||||
|
|
||||||
|
def run_async(async_command: Coroutine) -> None:
|
||||||
|
try:
|
||||||
|
asyncio.run(async_command)
|
||||||
|
except bumble.core.ProtocolError as error:
|
||||||
|
if error.error_namespace == 'att' and error.error_code in list(
|
||||||
|
bumble.profiles.bass.ApplicationError
|
||||||
|
):
|
||||||
|
message = bumble.profiles.bass.ApplicationError(error.error_code).name
|
||||||
|
else:
|
||||||
|
message = str(error)
|
||||||
|
|
||||||
|
print(
|
||||||
|
color('!!! An error occurred while executing the command:', 'red'), message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -379,7 +628,7 @@ def auracast(
|
|||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
|
|
||||||
@auracast.command('discover-broadcasts')
|
@auracast.command('scan')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
|
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
|
||||||
)
|
)
|
||||||
@@ -387,14 +636,50 @@ def auracast(
|
|||||||
'--sync-timeout',
|
'--sync-timeout',
|
||||||
metavar='SYNC_TIMEOUT',
|
metavar='SYNC_TIMEOUT',
|
||||||
type=float,
|
type=float,
|
||||||
default=5.0,
|
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
|
||||||
help='Sync timeout (in seconds)',
|
help='Sync timeout (in seconds)',
|
||||||
)
|
)
|
||||||
@click.argument('transport')
|
@click.argument('transport')
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
|
def scan(ctx, filter_duplicates, sync_timeout, transport):
|
||||||
"""Discover public broadcasts"""
|
"""Scan for public broadcasts"""
|
||||||
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
|
run_async(run_scan(filter_duplicates, sync_timeout, transport))
|
||||||
|
|
||||||
|
|
||||||
|
@auracast.command('assist')
|
||||||
|
@click.option(
|
||||||
|
'--broadcast-name',
|
||||||
|
metavar='BROADCAST_NAME',
|
||||||
|
help='Broadcast Name to tune to',
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
'--source-id',
|
||||||
|
metavar='SOURCE_ID',
|
||||||
|
type=int,
|
||||||
|
help='Source ID (for remove-source command)',
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
'--command',
|
||||||
|
type=click.Choice(
|
||||||
|
['monitor-state', 'add-source', 'modify-source', 'remove-source']
|
||||||
|
),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
@click.argument('transport')
|
||||||
|
@click.argument('address')
|
||||||
|
@click.pass_context
|
||||||
|
def assist(ctx, broadcast_name, source_id, command, transport, address):
|
||||||
|
"""Scan for broadcasts on behalf of a audio server"""
|
||||||
|
run_async(run_assist(broadcast_name, source_id, command, transport, address))
|
||||||
|
|
||||||
|
|
||||||
|
@auracast.command('pair')
|
||||||
|
@click.argument('transport')
|
||||||
|
@click.argument('address')
|
||||||
|
@click.pass_context
|
||||||
|
def pair(ctx, transport, address):
|
||||||
|
"""Pair with an audio server"""
|
||||||
|
run_async(run_pair(transport, address))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
|
|||||||
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
||||||
from bumble.gatt_client import CharacteristicProxy
|
from bumble.gatt_client import CharacteristicProxy
|
||||||
from bumble.hci import (
|
from bumble.hci import (
|
||||||
|
Address,
|
||||||
HCI_Constant,
|
HCI_Constant,
|
||||||
HCI_LE_1M_PHY,
|
HCI_LE_1M_PHY,
|
||||||
HCI_LE_2M_PHY,
|
HCI_LE_2M_PHY,
|
||||||
@@ -289,11 +290,7 @@ class ConsoleApp:
|
|||||||
device_config, hci_source, hci_sink
|
device_config, hci_source, hci_sink
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random_address = (
|
random_address = Address.generate_static_address()
|
||||||
f"{random.randint(192,255):02X}" # address is static random
|
|
||||||
)
|
|
||||||
for random_byte in random.sample(range(255), 5):
|
|
||||||
random_address += f":{random_byte:02X}"
|
|
||||||
self.append_to_log(f"Setting random address: {random_address}")
|
self.append_to_log(f"Setting random address: {random_address}")
|
||||||
self.device = Device.with_hci(
|
self.device = Device.with_hci(
|
||||||
'Bumble', random_address, hci_source, hci_sink
|
'Bumble', random_address, hci_source, hci_sink
|
||||||
@@ -503,21 +500,9 @@ class ConsoleApp:
|
|||||||
self.show_error('not connected')
|
self.show_error('not connected')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Discover all services, characteristics and descriptors
|
self.append_to_output('Service Discovery starting...')
|
||||||
self.append_to_output('discovering services...')
|
await self.connected_peer.discover_all()
|
||||||
await self.connected_peer.discover_services()
|
self.append_to_output('Service Discovery done!')
|
||||||
self.append_to_output(
|
|
||||||
f'found {len(self.connected_peer.services)} services,'
|
|
||||||
' discovering characteristics...'
|
|
||||||
)
|
|
||||||
await self.connected_peer.discover_characteristics()
|
|
||||||
self.append_to_output('found characteristics, discovering descriptors...')
|
|
||||||
for service in self.connected_peer.services:
|
|
||||||
for characteristic in service.characteristics:
|
|
||||||
await self.connected_peer.discover_descriptors(characteristic)
|
|
||||||
self.append_to_output('discovery completed')
|
|
||||||
|
|
||||||
self.show_remote_services(self.connected_peer.services)
|
|
||||||
|
|
||||||
async def discover_attributes(self):
|
async def discover_attributes(self):
|
||||||
if not self.connected_peer:
|
if not self.connected_peer:
|
||||||
|
|||||||
230
apps/device_info.py
Normal file
230
apps/device_info.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
# Copyright 2021-2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Iterable, Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from bumble.core import ProtocolError
|
||||||
|
from bumble.colors import color
|
||||||
|
from bumble.device import Device, Peer
|
||||||
|
from bumble.gatt import Service
|
||||||
|
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
|
||||||
|
from bumble.profiles.battery_service import BatteryServiceProxy
|
||||||
|
from bumble.profiles.gap import GenericAccessServiceProxy
|
||||||
|
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
|
||||||
|
from bumble.transport import open_transport_or_link
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def try_show(function: Callable, *args, **kwargs) -> None:
|
||||||
|
try:
|
||||||
|
await function(*args, **kwargs)
|
||||||
|
except ProtocolError as error:
|
||||||
|
print(color('ERROR:', 'red'), error)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def show_services(services: Iterable[Service]) -> None:
|
||||||
|
for service in services:
|
||||||
|
print(color(str(service), 'cyan'))
|
||||||
|
|
||||||
|
for characteristic in service.characteristics:
|
||||||
|
print(color(' ' + str(characteristic), 'magenta'))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def show_gap_information(
|
||||||
|
gap_service: GenericAccessServiceProxy,
|
||||||
|
):
|
||||||
|
print(color('### Generic Access Profile', 'yellow'))
|
||||||
|
|
||||||
|
if gap_service.device_name:
|
||||||
|
print(
|
||||||
|
color(' Device Name:', 'green'),
|
||||||
|
await gap_service.device_name.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if gap_service.appearance:
|
||||||
|
print(
|
||||||
|
color(' Appearance: ', 'green'),
|
||||||
|
await gap_service.appearance.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def show_device_information(
|
||||||
|
device_information_service: DeviceInformationServiceProxy,
|
||||||
|
):
|
||||||
|
print(color('### Device Information', 'yellow'))
|
||||||
|
|
||||||
|
if device_information_service.manufacturer_name:
|
||||||
|
print(
|
||||||
|
color(' Manufacturer Name:', 'green'),
|
||||||
|
await device_information_service.manufacturer_name.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_information_service.model_number:
|
||||||
|
print(
|
||||||
|
color(' Model Number: ', 'green'),
|
||||||
|
await device_information_service.model_number.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_information_service.serial_number:
|
||||||
|
print(
|
||||||
|
color(' Serial Number: ', 'green'),
|
||||||
|
await device_information_service.serial_number.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_information_service.firmware_revision:
|
||||||
|
print(
|
||||||
|
color(' Firmware Revision:', 'green'),
|
||||||
|
await device_information_service.firmware_revision.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def show_battery_level(
|
||||||
|
battery_service: BatteryServiceProxy,
|
||||||
|
):
|
||||||
|
print(color('### Battery Information', 'yellow'))
|
||||||
|
|
||||||
|
if battery_service.battery_level:
|
||||||
|
print(
|
||||||
|
color(' Battery Level:', 'green'),
|
||||||
|
await battery_service.battery_level.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def show_tmas(
|
||||||
|
tmas: TelephonyAndMediaAudioServiceProxy,
|
||||||
|
):
|
||||||
|
print(color('### Telephony And Media Audio Service', 'yellow'))
|
||||||
|
|
||||||
|
if tmas.role:
|
||||||
|
print(
|
||||||
|
color(' Role:', 'green'),
|
||||||
|
await tmas.role.read_value(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||||
|
try:
|
||||||
|
# Discover all services
|
||||||
|
print(color('### Discovering Services and Characteristics', 'magenta'))
|
||||||
|
await peer.discover_services()
|
||||||
|
for service in peer.services:
|
||||||
|
await service.discover_characteristics()
|
||||||
|
|
||||||
|
print(color('=== Services ===', 'yellow'))
|
||||||
|
show_services(peer.services)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
|
||||||
|
await try_show(show_gap_information, gap_service)
|
||||||
|
|
||||||
|
if device_information_service := peer.create_service_proxy(
|
||||||
|
DeviceInformationServiceProxy
|
||||||
|
):
|
||||||
|
await try_show(show_device_information, device_information_service)
|
||||||
|
|
||||||
|
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
|
||||||
|
await try_show(show_battery_level, battery_service)
|
||||||
|
|
||||||
|
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
|
||||||
|
await try_show(show_tmas, tmas)
|
||||||
|
|
||||||
|
if done is not None:
|
||||||
|
done.set_result(None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print(color('!!! Operation canceled', 'red'))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def async_main(device_config, encrypt, transport, address_or_name):
|
||||||
|
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||||
|
|
||||||
|
# Create a device
|
||||||
|
if device_config:
|
||||||
|
device = Device.from_config_file_with_hci(
|
||||||
|
device_config, hci_source, hci_sink
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
device = Device.with_hci(
|
||||||
|
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||||
|
)
|
||||||
|
await device.power_on()
|
||||||
|
|
||||||
|
if address_or_name:
|
||||||
|
# Connect to the target peer
|
||||||
|
print(color('>>> Connecting...', 'green'))
|
||||||
|
connection = await device.connect(address_or_name)
|
||||||
|
print(color('>>> Connected', 'green'))
|
||||||
|
|
||||||
|
# Encrypt the connection if required
|
||||||
|
if encrypt:
|
||||||
|
print(color('+++ Encrypting connection...', 'blue'))
|
||||||
|
await connection.encrypt()
|
||||||
|
print(color('+++ Encryption established', 'blue'))
|
||||||
|
|
||||||
|
await show_device_info(Peer(connection), None)
|
||||||
|
else:
|
||||||
|
# Wait for a connection
|
||||||
|
done = asyncio.get_running_loop().create_future()
|
||||||
|
device.on(
|
||||||
|
'connection',
|
||||||
|
lambda connection: asyncio.create_task(
|
||||||
|
show_device_info(Peer(connection), done)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await device.start_advertising(auto_restart=True)
|
||||||
|
|
||||||
|
print(color('### Waiting for connection...', 'blue'))
|
||||||
|
await done
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@click.command()
|
||||||
|
@click.option('--device-config', help='Device configuration', type=click.Path())
|
||||||
|
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
|
||||||
|
@click.argument('transport')
|
||||||
|
@click.argument('address-or-name', required=False)
|
||||||
|
def main(device_config, encrypt, transport, address_or_name):
|
||||||
|
"""
|
||||||
|
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
|
||||||
|
wait for an incoming connection.
|
||||||
|
"""
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -75,11 +75,15 @@ async def async_main(device_config, encrypt, transport, address_or_name):
|
|||||||
|
|
||||||
if address_or_name:
|
if address_or_name:
|
||||||
# Connect to the target peer
|
# Connect to the target peer
|
||||||
|
print(color('>>> Connecting...', 'green'))
|
||||||
connection = await device.connect(address_or_name)
|
connection = await device.connect(address_or_name)
|
||||||
|
print(color('>>> Connected', 'green'))
|
||||||
|
|
||||||
# Encrypt the connection if required
|
# Encrypt the connection if required
|
||||||
if encrypt:
|
if encrypt:
|
||||||
|
print(color('+++ Encrypting connection...', 'blue'))
|
||||||
await connection.encrypt()
|
await connection.encrypt()
|
||||||
|
print(color('+++ Encryption established', 'blue'))
|
||||||
|
|
||||||
await dump_gatt_db(Peer(connection), None)
|
await dump_gatt_db(Peer(connection), None)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ import ctypes
|
|||||||
import wasmtime
|
import wasmtime
|
||||||
import wasmtime.loader
|
import wasmtime.loader
|
||||||
import liblc3 # type: ignore
|
import liblc3 # type: ignore
|
||||||
import logging
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import aiohttp.web
|
import aiohttp.web
|
||||||
@@ -43,7 +42,7 @@ from bumble.core import AdvertisingData
|
|||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
||||||
from bumble.transport import open_transport
|
from bumble.transport import open_transport
|
||||||
from bumble.profiles import bap
|
from bumble.profiles import ascs, bap, pacs
|
||||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -57,8 +56,8 @@ logger = logging.getLogger(__name__)
|
|||||||
DEFAULT_UI_PORT = 7654
|
DEFAULT_UI_PORT = 7654
|
||||||
|
|
||||||
|
|
||||||
def _sink_pac_record() -> bap.PacRecord:
|
def _sink_pac_record() -> pacs.PacRecord:
|
||||||
return bap.PacRecord(
|
return pacs.PacRecord(
|
||||||
coding_format=CodingFormat(CodecID.LC3),
|
coding_format=CodingFormat(CodecID.LC3),
|
||||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||||
supported_sampling_frequencies=(
|
supported_sampling_frequencies=(
|
||||||
@@ -79,8 +78,8 @@ def _sink_pac_record() -> bap.PacRecord:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _source_pac_record() -> bap.PacRecord:
|
def _source_pac_record() -> pacs.PacRecord:
|
||||||
return bap.PacRecord(
|
return pacs.PacRecord(
|
||||||
coding_format=CodingFormat(CodecID.LC3),
|
coding_format=CodingFormat(CodecID.LC3),
|
||||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||||
supported_sampling_frequencies=(
|
supported_sampling_frequencies=(
|
||||||
@@ -447,7 +446,7 @@ class Speaker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.device.add_service(
|
self.device.add_service(
|
||||||
bap.PublishedAudioCapabilitiesService(
|
pacs.PublishedAudioCapabilitiesService(
|
||||||
supported_source_context=bap.ContextType(0xFFFF),
|
supported_source_context=bap.ContextType(0xFFFF),
|
||||||
available_source_context=bap.ContextType(0xFFFF),
|
available_source_context=bap.ContextType(0xFFFF),
|
||||||
supported_sink_context=bap.ContextType(0xFFFF), # All context types
|
supported_sink_context=bap.ContextType(0xFFFF), # All context types
|
||||||
@@ -461,10 +460,10 @@ class Speaker:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
ascs = bap.AudioStreamControlService(
|
ascs_service = ascs.AudioStreamControlService(
|
||||||
self.device, sink_ase_id=[1], source_ase_id=[2]
|
self.device, sink_ase_id=[1], source_ase_id=[2]
|
||||||
)
|
)
|
||||||
self.device.add_service(ascs)
|
self.device.add_service(ascs_service)
|
||||||
|
|
||||||
advertising_data = bytes(
|
advertising_data = bytes(
|
||||||
AdvertisingData(
|
AdvertisingData(
|
||||||
@@ -479,13 +478,13 @@ class Speaker:
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||||
bytes(bap.PublishedAudioCapabilitiesService.UUID),
|
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
) + bytes(bap.UnicastServerAdvertisingData())
|
) + bytes(bap.UnicastServerAdvertisingData())
|
||||||
|
|
||||||
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
|
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
|
||||||
codec_config = ase.codec_specific_configuration
|
codec_config = ase.codec_specific_configuration
|
||||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||||
pcm = decode(
|
pcm = decode(
|
||||||
@@ -495,12 +494,12 @@ class Speaker:
|
|||||||
)
|
)
|
||||||
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
||||||
|
|
||||||
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
|
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
|
||||||
if ase.state == bap.AseStateMachine.State.STREAMING:
|
if ase.state == ascs.AseStateMachine.State.STREAMING:
|
||||||
codec_config = ase.codec_specific_configuration
|
codec_config = ase.codec_specific_configuration
|
||||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||||
assert ase.cis_link
|
assert ase.cis_link
|
||||||
if ase.role == bap.AudioRole.SOURCE:
|
if ase.role == ascs.AudioRole.SOURCE:
|
||||||
ase.cis_link.abort_on(
|
ase.cis_link.abort_on(
|
||||||
'disconnection',
|
'disconnection',
|
||||||
lc3_source_task(
|
lc3_source_task(
|
||||||
@@ -516,10 +515,10 @@ class Speaker:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
||||||
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
|
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
|
||||||
codec_config = ase.codec_specific_configuration
|
codec_config = ase.codec_specific_configuration
|
||||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||||
if ase.role == bap.AudioRole.SOURCE:
|
if ase.role == ascs.AudioRole.SOURCE:
|
||||||
setup_encoders(
|
setup_encoders(
|
||||||
codec_config.sampling_frequency.hz,
|
codec_config.sampling_frequency.hz,
|
||||||
codec_config.frame_duration.us,
|
codec_config.frame_duration.us,
|
||||||
@@ -532,7 +531,7 @@ class Speaker:
|
|||||||
codec_config.audio_channel_allocation.channel_count,
|
codec_config.audio_channel_allocation.channel_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
for ase in ascs.ase_state_machines.values():
|
for ase in ascs_service.ase_state_machines.values():
|
||||||
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
|
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
|
||||||
|
|
||||||
await self.device.power_on()
|
await self.device.power_on()
|
||||||
|
|||||||
18
bumble/at.py
18
bumble/at.py
@@ -14,13 +14,19 @@
|
|||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
|
|
||||||
|
|
||||||
|
class AtParsingError(core.InvalidPacketError):
|
||||||
|
"""Error raised when parsing AT commands fails."""
|
||||||
|
|
||||||
|
|
||||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||||
"""Split input parameters into tokens.
|
"""Split input parameters into tokens.
|
||||||
Removes space characters outside of double quote blocks:
|
Removes space characters outside of double quote blocks:
|
||||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||||
are ignored [..], unless they are embedded in numeric or string constants"
|
are ignored [..], unless they are embedded in numeric or string constants"
|
||||||
Raises ValueError in case of invalid input string."""
|
Raises AtParsingError in case of invalid input string."""
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
in_quotes = False
|
in_quotes = False
|
||||||
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
|||||||
token = bytearray()
|
token = bytearray()
|
||||||
elif char == b'(':
|
elif char == b'(':
|
||||||
if len(token) > 0:
|
if len(token) > 0:
|
||||||
raise ValueError("open_paren following regular character")
|
raise AtParsingError("open_paren following regular character")
|
||||||
tokens.append(char)
|
tokens.append(char)
|
||||||
elif char == b'"':
|
elif char == b'"':
|
||||||
if len(token) > 0:
|
if len(token) > 0:
|
||||||
raise ValueError("quote following regular character")
|
raise AtParsingError("quote following regular character")
|
||||||
in_quotes = True
|
in_quotes = True
|
||||||
token.extend(char)
|
token.extend(char)
|
||||||
else:
|
else:
|
||||||
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
|||||||
|
|
||||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||||
"""Parse the parameters using the comma and parenthesis separators.
|
"""Parse the parameters using the comma and parenthesis separators.
|
||||||
Raises ValueError in case of invalid input string."""
|
Raises AtParsingError in case of invalid input string."""
|
||||||
|
|
||||||
tokens = tokenize_parameters(buffer)
|
tokens = tokenize_parameters(buffer)
|
||||||
accumulator: List[list] = [[]]
|
accumulator: List[list] = [[]]
|
||||||
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
|||||||
accumulator.append([])
|
accumulator.append([])
|
||||||
elif token == b')':
|
elif token == b')':
|
||||||
if len(accumulator) < 2:
|
if len(accumulator) < 2:
|
||||||
raise ValueError("close_paren without matching open_paren")
|
raise AtParsingError("close_paren without matching open_paren")
|
||||||
accumulator[-1].append(current)
|
accumulator[-1].append(current)
|
||||||
current = accumulator.pop()
|
current = accumulator.pop()
|
||||||
else:
|
else:
|
||||||
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
|||||||
|
|
||||||
accumulator[-1].append(current)
|
accumulator[-1].append(current)
|
||||||
if len(accumulator) > 1:
|
if len(accumulator) > 1:
|
||||||
raise ValueError("missing close_paren")
|
raise AtParsingError("missing close_paren")
|
||||||
return accumulator[0]
|
return accumulator[0]
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import enum
|
|||||||
import struct
|
import struct
|
||||||
from typing import Dict, Type, Union, Tuple
|
from typing import Dict, Type, Union, Tuple
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
from bumble.utils import OpenIntEnum
|
from bumble.utils import OpenIntEnum
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +89,9 @@ class Frame:
|
|||||||
short_name = subclass.__name__.replace("ResponseFrame", "")
|
short_name = subclass.__name__.replace("ResponseFrame", "")
|
||||||
category_class = ResponseFrame
|
category_class = ResponseFrame
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"invalid subclass name {subclass.__name__}")
|
raise core.InvalidArgumentError(
|
||||||
|
f"invalid subclass name {subclass.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
uppercase_indexes = [
|
uppercase_indexes = [
|
||||||
i for i in range(len(short_name)) if short_name[i].isupper()
|
i for i in range(len(short_name)) if short_name[i].isupper()
|
||||||
@@ -106,7 +109,7 @@ class Frame:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_bytes(data: bytes) -> Frame:
|
def from_bytes(data: bytes) -> Frame:
|
||||||
if data[0] >> 4 != 0:
|
if data[0] >> 4 != 0:
|
||||||
raise ValueError("first 4 bits must be 0s")
|
raise core.InvalidPacketError("first 4 bits must be 0s")
|
||||||
|
|
||||||
ctype_or_response = data[0] & 0xF
|
ctype_or_response = data[0] & 0xF
|
||||||
subunit_type = Frame.SubunitType(data[1] >> 3)
|
subunit_type = Frame.SubunitType(data[1] >> 3)
|
||||||
@@ -122,7 +125,7 @@ class Frame:
|
|||||||
# Extended to the next byte
|
# Extended to the next byte
|
||||||
extension = data[2]
|
extension = data[2]
|
||||||
if extension == 0:
|
if extension == 0:
|
||||||
raise ValueError("extended subunit ID value reserved")
|
raise core.InvalidPacketError("extended subunit ID value reserved")
|
||||||
if extension == 0xFF:
|
if extension == 0xFF:
|
||||||
subunit_id = 5 + 254 + data[3]
|
subunit_id = 5 + 254 + data[3]
|
||||||
opcode_offset = 4
|
opcode_offset = 4
|
||||||
@@ -131,7 +134,7 @@ class Frame:
|
|||||||
opcode_offset = 3
|
opcode_offset = 3
|
||||||
|
|
||||||
elif subunit_id == 6:
|
elif subunit_id == 6:
|
||||||
raise ValueError("reserved subunit ID")
|
raise core.InvalidPacketError("reserved subunit ID")
|
||||||
|
|
||||||
opcode = Frame.OperationCode(data[opcode_offset])
|
opcode = Frame.OperationCode(data[opcode_offset])
|
||||||
operands = data[opcode_offset + 1 :]
|
operands = data[opcode_offset + 1 :]
|
||||||
@@ -448,7 +451,7 @@ class PassThroughFrame:
|
|||||||
operation_data: bytes,
|
operation_data: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
if len(operation_data) > 255:
|
if len(operation_data) > 255:
|
||||||
raise ValueError("operation data must be <= 255 bytes")
|
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
|
||||||
self.state_flag = state_flag
|
self.state_flag = state_flag
|
||||||
self.operation_id = operation_id
|
self.operation_id = operation_id
|
||||||
self.operation_data = operation_data
|
self.operation_data = operation_data
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional
|
|||||||
|
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble import avc
|
from bumble import avc
|
||||||
|
from bumble import core
|
||||||
from bumble import l2cap
|
from bumble import l2cap
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -275,7 +276,7 @@ class Protocol:
|
|||||||
self, pid: int, handler: Protocol.CommandHandler
|
self, pid: int, handler: Protocol.CommandHandler
|
||||||
) -> None:
|
) -> None:
|
||||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||||
raise ValueError("command handler not registered")
|
raise core.InvalidArgumentError("command handler not registered")
|
||||||
del self.command_handlers[pid]
|
del self.command_handlers[pid]
|
||||||
|
|
||||||
def register_response_handler(
|
def register_response_handler(
|
||||||
@@ -287,5 +288,5 @@ class Protocol:
|
|||||||
self, pid: int, handler: Protocol.ResponseHandler
|
self, pid: int, handler: Protocol.ResponseHandler
|
||||||
) -> None:
|
) -> None:
|
||||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||||
raise ValueError("response handler not registered")
|
raise core.InvalidArgumentError("response handler not registered")
|
||||||
del self.response_handlers[pid]
|
del self.response_handlers[pid]
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from .core import (
|
|||||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||||
InvalidStateError,
|
InvalidStateError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
|
InvalidArgumentError,
|
||||||
name_or_number,
|
name_or_number,
|
||||||
)
|
)
|
||||||
from .a2dp import (
|
from .a2dp import (
|
||||||
@@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
|
|||||||
signal_identifier_str = name[:-7]
|
signal_identifier_str = name[:-7]
|
||||||
message_type = Message.MessageType.RESPONSE_REJECT
|
message_type = Message.MessageType.RESPONSE_REJECT
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid class name')
|
raise InvalidArgumentError('invalid class name')
|
||||||
|
|
||||||
subclass.message_type = message_type
|
subclass.message_type = message_type
|
||||||
|
|
||||||
@@ -2162,6 +2163,9 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
|
|||||||
def on_abort_command(self):
|
def on_abort_command(self):
|
||||||
self.emit('abort')
|
self.emit('abort')
|
||||||
|
|
||||||
|
def on_delayreport_command(self, delay: int):
|
||||||
|
self.emit('delay_report', delay)
|
||||||
|
|
||||||
def on_rtp_channel_open(self):
|
def on_rtp_channel_open(self):
|
||||||
self.emit('rtp_channel_open')
|
self.emit('rtp_channel_open')
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from bumble.sdp import (
|
|||||||
)
|
)
|
||||||
from bumble.utils import AsyncRunner, OpenIntEnum
|
from bumble.utils import AsyncRunner, OpenIntEnum
|
||||||
from bumble.core import (
|
from bumble.core import (
|
||||||
|
InvalidArgumentError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
BT_L2CAP_PROTOCOL_ID,
|
BT_L2CAP_PROTOCOL_ID,
|
||||||
BT_AVCTP_PROTOCOL_ID,
|
BT_AVCTP_PROTOCOL_ID,
|
||||||
@@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter):
|
|||||||
def notify_track_changed(self, identifier: bytes) -> None:
|
def notify_track_changed(self, identifier: bytes) -> None:
|
||||||
"""Notify the connected peer of a Track change."""
|
"""Notify the connected peer of a Track change."""
|
||||||
if len(identifier) != 8:
|
if len(identifier) != 8:
|
||||||
raise ValueError("identifier must be 8 bytes")
|
raise InvalidArgumentError("identifier must be 8 bytes")
|
||||||
self.notify_event(TrackChangedEvent(identifier))
|
self.notify_event(TrackChangedEvent(identifier))
|
||||||
|
|
||||||
def notify_playback_position_changed(self, position: int) -> None:
|
def notify_playback_position_changed(self, position: int) -> None:
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class BitReader:
|
class BitReader:
|
||||||
@@ -40,7 +42,7 @@ class BitReader:
|
|||||||
""" "Read up to 32 bits."""
|
""" "Read up to 32 bits."""
|
||||||
|
|
||||||
if bits > 32:
|
if bits > 32:
|
||||||
raise ValueError('maximum read size is 32')
|
raise core.InvalidArgumentError('maximum read size is 32')
|
||||||
|
|
||||||
if self.bits_cached >= bits:
|
if self.bits_cached >= bits:
|
||||||
# We have enough bits.
|
# We have enough bits.
|
||||||
@@ -53,7 +55,7 @@ class BitReader:
|
|||||||
feed_size = len(feed_bytes)
|
feed_size = len(feed_bytes)
|
||||||
feed_int = int.from_bytes(feed_bytes, byteorder='big')
|
feed_int = int.from_bytes(feed_bytes, byteorder='big')
|
||||||
if 8 * feed_size + self.bits_cached < bits:
|
if 8 * feed_size + self.bits_cached < bits:
|
||||||
raise ValueError('trying to read past the data')
|
raise core.InvalidArgumentError('trying to read past the data')
|
||||||
self.byte_position += feed_size
|
self.byte_position += feed_size
|
||||||
|
|
||||||
# Combine the new cache and the old cache
|
# Combine the new cache and the old cache
|
||||||
@@ -68,7 +70,7 @@ class BitReader:
|
|||||||
|
|
||||||
def read_bytes(self, count: int):
|
def read_bytes(self, count: int):
|
||||||
if self.bit_position + 8 * count > 8 * len(self.data):
|
if self.bit_position + 8 * count > 8 * len(self.data):
|
||||||
raise ValueError('not enough data')
|
raise core.InvalidArgumentError('not enough data')
|
||||||
|
|
||||||
if self.bit_position % 8:
|
if self.bit_position % 8:
|
||||||
# Not byte aligned
|
# Not byte aligned
|
||||||
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def program_config_element(reader: BitReader):
|
def program_config_element(reader: BitReader):
|
||||||
raise ValueError('program_config_element not supported')
|
raise core.InvalidPacketError('program_config_element not supported')
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GASpecificConfig:
|
class GASpecificConfig:
|
||||||
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
|
|||||||
aac_spectral_data_resilience_flags = reader.read(1)
|
aac_spectral_data_resilience_flags = reader.read(1)
|
||||||
extension_flag_3 = reader.read(1)
|
extension_flag_3 = reader.read(1)
|
||||||
if extension_flag_3 == 1:
|
if extension_flag_3 == 1:
|
||||||
raise ValueError('extensionFlag3 == 1 not supported')
|
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def audio_object_type(reader: BitReader):
|
def audio_object_type(reader: BitReader):
|
||||||
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
|
|||||||
reader, self.channel_configuration, self.audio_object_type
|
reader, self.channel_configuration, self.audio_object_type
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise core.InvalidPacketError(
|
||||||
f'audioObjectType {self.audio_object_type} not supported'
|
f'audioObjectType {self.audio_object_type} not supported'
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
|
|||||||
else:
|
else:
|
||||||
audio_mux_version_a = 0
|
audio_mux_version_a = 0
|
||||||
if audio_mux_version_a != 0:
|
if audio_mux_version_a != 0:
|
||||||
raise ValueError('audioMuxVersionA != 0 not supported')
|
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
|
||||||
if audio_mux_version == 1:
|
if audio_mux_version == 1:
|
||||||
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
||||||
stream_cnt = 0
|
stream_cnt = 0
|
||||||
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
|
|||||||
num_sub_frames = reader.read(6)
|
num_sub_frames = reader.read(6)
|
||||||
num_program = reader.read(4)
|
num_program = reader.read(4)
|
||||||
if num_program != 0:
|
if num_program != 0:
|
||||||
raise ValueError('num_program != 0 not supported')
|
raise core.InvalidPacketError('num_program != 0 not supported')
|
||||||
num_layer = reader.read(3)
|
num_layer = reader.read(3)
|
||||||
if num_layer != 0:
|
if num_layer != 0:
|
||||||
raise ValueError('num_layer != 0 not supported')
|
raise core.InvalidPacketError('num_layer != 0 not supported')
|
||||||
if audio_mux_version == 0:
|
if audio_mux_version == 0:
|
||||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||||
reader
|
reader
|
||||||
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
|
|||||||
)
|
)
|
||||||
audio_specific_config_len = reader.bit_position - marker
|
audio_specific_config_len = reader.bit_position - marker
|
||||||
if asc_len < audio_specific_config_len:
|
if asc_len < audio_specific_config_len:
|
||||||
raise ValueError('audio_specific_config_len > asc_len')
|
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
|
||||||
asc_len -= audio_specific_config_len
|
asc_len -= audio_specific_config_len
|
||||||
reader.skip(asc_len)
|
reader.skip(asc_len)
|
||||||
frame_length_type = reader.read(3)
|
frame_length_type = reader.read(3)
|
||||||
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
|
|||||||
elif frame_length_type == 1:
|
elif frame_length_type == 1:
|
||||||
frame_length = reader.read(9)
|
frame_length = reader.read(9)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'frame_length_type {frame_length_type} not supported')
|
raise core.InvalidPacketError(
|
||||||
|
f'frame_length_type {frame_length_type} not supported'
|
||||||
|
)
|
||||||
|
|
||||||
self.other_data_present = reader.read(1)
|
self.other_data_present = reader.read(1)
|
||||||
if self.other_data_present:
|
if self.other_data_present:
|
||||||
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
|
|||||||
|
|
||||||
def __init__(self, reader: BitReader, mux_config_present: int):
|
def __init__(self, reader: BitReader, mux_config_present: int):
|
||||||
if mux_config_present == 0:
|
if mux_config_present == 0:
|
||||||
raise ValueError('muxConfigPresent == 0 not supported')
|
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
|
||||||
|
|
||||||
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
||||||
use_same_stream_mux = reader.read(1)
|
use_same_stream_mux = reader.read(1)
|
||||||
if use_same_stream_mux:
|
if use_same_stream_mux:
|
||||||
raise ValueError('useSameStreamMux == 1 not supported')
|
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
|
||||||
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
||||||
|
|
||||||
# We only support:
|
# We only support:
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ from functools import partial
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class ColorError(ValueError):
|
||||||
|
"""Error raised when a color spec is invalid."""
|
||||||
|
|
||||||
|
|
||||||
# ANSI color names. There is also a "default"
|
# ANSI color names. There is also a "default"
|
||||||
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
|
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
|
||||||
|
|
||||||
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
|
|||||||
elif isinstance(spec, int) and 0 <= spec <= 255:
|
elif isinstance(spec, int) and 0 <= spec <= 255:
|
||||||
return _join(base + 8, 5, spec)
|
return _join(base + 8, 5, spec)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid color spec "%s"' % spec)
|
raise ColorError('Invalid color spec "%s"' % spec)
|
||||||
|
|
||||||
|
|
||||||
def color(
|
def color(
|
||||||
@@ -72,7 +76,7 @@ def color(
|
|||||||
if style_part in STYLES:
|
if style_part in STYLES:
|
||||||
codes.append(STYLES.index(style_part))
|
codes.append(STYLES.index(style_part))
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid style "%s"' % style_part)
|
raise ColorError('Invalid style "%s"' % style_part)
|
||||||
|
|
||||||
if codes:
|
if codes:
|
||||||
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
|
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
|
||||||
|
|||||||
@@ -79,7 +79,13 @@ def get_dict_key_by_value(dictionary, value):
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Exceptions
|
# Exceptions
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class BaseError(Exception):
|
|
||||||
|
|
||||||
|
class BaseBumbleError(Exception):
|
||||||
|
"""Base Error raised by Bumble."""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseError(BaseBumbleError):
|
||||||
"""Base class for errors with an error code, error name and namespace"""
|
"""Base class for errors with an error code, error name and namespace"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -118,18 +124,42 @@ class ProtocolError(BaseError):
|
|||||||
"""Protocol Error"""
|
"""Protocol Error"""
|
||||||
|
|
||||||
|
|
||||||
class TimeoutError(Exception): # pylint: disable=redefined-builtin
|
class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin
|
||||||
"""Timeout Error"""
|
"""Timeout Error"""
|
||||||
|
|
||||||
|
|
||||||
class CommandTimeoutError(Exception):
|
class CommandTimeoutError(BaseBumbleError):
|
||||||
"""Command Timeout Error"""
|
"""Command Timeout Error"""
|
||||||
|
|
||||||
|
|
||||||
class InvalidStateError(Exception):
|
class InvalidStateError(BaseBumbleError):
|
||||||
"""Invalid State Error"""
|
"""Invalid State Error"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidArgumentError(BaseBumbleError, ValueError):
|
||||||
|
"""Invalid Argument Error"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPacketError(BaseBumbleError, ValueError):
|
||||||
|
"""Invalid Packet Error"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOperationError(BaseBumbleError, RuntimeError):
|
||||||
|
"""Invalid Operation Error"""
|
||||||
|
|
||||||
|
|
||||||
|
class NotSupportedError(BaseBumbleError, RuntimeError):
|
||||||
|
"""Not Supported"""
|
||||||
|
|
||||||
|
|
||||||
|
class OutOfResourcesError(BaseBumbleError, RuntimeError):
|
||||||
|
"""Out of Resources Error"""
|
||||||
|
|
||||||
|
|
||||||
|
class UnreachableError(BaseBumbleError):
|
||||||
|
"""The code path raising this error should be unreachable."""
|
||||||
|
|
||||||
|
|
||||||
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
|
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
|
||||||
"""Connection Error"""
|
"""Connection Error"""
|
||||||
|
|
||||||
@@ -188,12 +218,12 @@ class UUID:
|
|||||||
or uuid_str_or_int[18] != '-'
|
or uuid_str_or_int[18] != '-'
|
||||||
or uuid_str_or_int[23] != '-'
|
or uuid_str_or_int[23] != '-'
|
||||||
):
|
):
|
||||||
raise ValueError('invalid UUID format')
|
raise InvalidArgumentError('invalid UUID format')
|
||||||
uuid_str = uuid_str_or_int.replace('-', '')
|
uuid_str = uuid_str_or_int.replace('-', '')
|
||||||
else:
|
else:
|
||||||
uuid_str = uuid_str_or_int
|
uuid_str = uuid_str_or_int
|
||||||
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
|
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
|
||||||
raise ValueError(f"invalid UUID format: {uuid_str}")
|
raise InvalidArgumentError(f"invalid UUID format: {uuid_str}")
|
||||||
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@@ -218,7 +248,7 @@ class UUID:
|
|||||||
|
|
||||||
return self.register()
|
return self.register()
|
||||||
|
|
||||||
raise ValueError('only 2, 4 and 16 bytes are allowed')
|
raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
|
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
|
||||||
|
|||||||
493
bumble/device.py
493
bumble/device.py
@@ -27,6 +27,7 @@ import copy
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
@@ -50,6 +51,7 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
|
from bumble import hci
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
|
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
|
||||||
from .gatt import Characteristic, Descriptor, Service
|
from .gatt import Characteristic, Descriptor, Service
|
||||||
@@ -111,6 +113,7 @@ from .hci import (
|
|||||||
HCI_LE_Periodic_Advertising_Create_Sync_Command,
|
HCI_LE_Periodic_Advertising_Create_Sync_Command,
|
||||||
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command,
|
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command,
|
||||||
HCI_LE_Periodic_Advertising_Report_Event,
|
HCI_LE_Periodic_Advertising_Report_Event,
|
||||||
|
HCI_LE_Periodic_Advertising_Sync_Transfer_Command,
|
||||||
HCI_LE_Periodic_Advertising_Terminate_Sync_Command,
|
HCI_LE_Periodic_Advertising_Terminate_Sync_Command,
|
||||||
HCI_LE_Enable_Encryption_Command,
|
HCI_LE_Enable_Encryption_Command,
|
||||||
HCI_LE_Extended_Advertising_Report_Event,
|
HCI_LE_Extended_Advertising_Report_Event,
|
||||||
@@ -167,21 +170,29 @@ from .hci import (
|
|||||||
OwnAddressType,
|
OwnAddressType,
|
||||||
LeFeature,
|
LeFeature,
|
||||||
LeFeatureMask,
|
LeFeatureMask,
|
||||||
|
LmpFeatureMask,
|
||||||
Phy,
|
Phy,
|
||||||
phy_list_to_bits,
|
phy_list_to_bits,
|
||||||
)
|
)
|
||||||
from .host import Host
|
from .host import Host
|
||||||
from .gap import GenericAccessService
|
from .profiles.gap import GenericAccessService
|
||||||
from .core import (
|
from .core import (
|
||||||
BT_BR_EDR_TRANSPORT,
|
BT_BR_EDR_TRANSPORT,
|
||||||
BT_CENTRAL_ROLE,
|
BT_CENTRAL_ROLE,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
BT_PERIPHERAL_ROLE,
|
BT_PERIPHERAL_ROLE,
|
||||||
AdvertisingData,
|
AdvertisingData,
|
||||||
|
BaseBumbleError,
|
||||||
ConnectionParameterUpdateError,
|
ConnectionParameterUpdateError,
|
||||||
CommandTimeoutError,
|
CommandTimeoutError,
|
||||||
|
ConnectionParameters,
|
||||||
ConnectionPHY,
|
ConnectionPHY,
|
||||||
|
InvalidArgumentError,
|
||||||
|
InvalidOperationError,
|
||||||
InvalidStateError,
|
InvalidStateError,
|
||||||
|
NotSupportedError,
|
||||||
|
OutOfResourcesError,
|
||||||
|
UnreachableError,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AsyncRunner,
|
AsyncRunner,
|
||||||
@@ -196,13 +207,13 @@ from .keys import (
|
|||||||
KeyStore,
|
KeyStore,
|
||||||
PairingKeys,
|
PairingKeys,
|
||||||
)
|
)
|
||||||
from .pairing import PairingConfig
|
from bumble import pairing
|
||||||
from . import gatt_client
|
from bumble import gatt_client
|
||||||
from . import gatt_server
|
from bumble import gatt_server
|
||||||
from . import smp
|
from bumble import smp
|
||||||
from . import sdp
|
from bumble import sdp
|
||||||
from . import l2cap
|
from bumble import l2cap
|
||||||
from . import core
|
from bumble import core
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .transport.common import TransportSource, TransportSink
|
from .transport.common import TransportSource, TransportSink
|
||||||
@@ -253,8 +264,9 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
|
|||||||
DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
|
DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
|
||||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
|
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
|
||||||
)
|
)
|
||||||
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
|
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
|
||||||
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
|
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
|
||||||
|
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
@@ -266,6 +278,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Classes
|
# Classes
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
class ObjectLookupError(BaseBumbleError):
|
||||||
|
"""Error raised when failed to lookup an object."""
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -958,20 +972,25 @@ class PeriodicAdvertisingSync(EventEmitter):
|
|||||||
response = await self.device.send_command(
|
response = await self.device.send_command(
|
||||||
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(),
|
HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(),
|
||||||
)
|
)
|
||||||
if response.status == HCI_SUCCESS:
|
if response.return_parameters == HCI_SUCCESS:
|
||||||
if self in self.device.periodic_advertising_syncs:
|
if self in self.device.periodic_advertising_syncs:
|
||||||
self.device.periodic_advertising_syncs.remove(self)
|
self.device.periodic_advertising_syncs.remove(self)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST):
|
if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST):
|
||||||
self.state = self.State.TERMINATED
|
self.state = self.State.TERMINATED
|
||||||
await self.device.send_command(
|
if self.sync_handle is not None:
|
||||||
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
|
await self.device.send_command(
|
||||||
sync_handle=self.sync_handle
|
HCI_LE_Periodic_Advertising_Terminate_Sync_Command(
|
||||||
|
sync_handle=self.sync_handle
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
self.device.periodic_advertising_syncs.remove(self)
|
self.device.periodic_advertising_syncs.remove(self)
|
||||||
|
|
||||||
|
async def transfer(self, connection: Connection, service_data: int = 0) -> None:
|
||||||
|
if self.sync_handle is not None:
|
||||||
|
await connection.transfer_periodic_sync(self.sync_handle, service_data)
|
||||||
|
|
||||||
def on_establishment(
|
def on_establishment(
|
||||||
self,
|
self,
|
||||||
status,
|
status,
|
||||||
@@ -1133,6 +1152,15 @@ class Peer:
|
|||||||
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
|
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
|
||||||
return await self.gatt_client.discover_attributes()
|
return await self.gatt_client.discover_attributes()
|
||||||
|
|
||||||
|
async def discover_all(self):
|
||||||
|
await self.discover_services()
|
||||||
|
for service in self.services:
|
||||||
|
await self.discover_characteristics(service=service)
|
||||||
|
|
||||||
|
for service in self.services:
|
||||||
|
for characteristic in service.characteristics:
|
||||||
|
await self.discover_descriptors(characteristic=characteristic)
|
||||||
|
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self,
|
self,
|
||||||
characteristic: gatt_client.CharacteristicProxy,
|
characteristic: gatt_client.CharacteristicProxy,
|
||||||
@@ -1172,12 +1200,29 @@ class Peer:
|
|||||||
return self.gatt_client.get_services_by_uuid(uuid)
|
return self.gatt_client.get_services_by_uuid(uuid)
|
||||||
|
|
||||||
def get_characteristics_by_uuid(
|
def get_characteristics_by_uuid(
|
||||||
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
|
self,
|
||||||
|
uuid: core.UUID,
|
||||||
|
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
|
||||||
) -> List[gatt_client.CharacteristicProxy]:
|
) -> List[gatt_client.CharacteristicProxy]:
|
||||||
|
if isinstance(service, core.UUID):
|
||||||
|
return list(
|
||||||
|
itertools.chain(
|
||||||
|
*[
|
||||||
|
self.get_characteristics_by_uuid(uuid, s)
|
||||||
|
for s in self.get_services_by_uuid(service)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
|
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
|
||||||
|
|
||||||
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
|
def create_service_proxy(
|
||||||
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
|
self, proxy_class: Type[_PROXY_CLASS]
|
||||||
|
) -> Optional[_PROXY_CLASS]:
|
||||||
|
if proxy := proxy_class.from_client(self.gatt_client):
|
||||||
|
return cast(_PROXY_CLASS, proxy)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def discover_service_and_create_proxy(
|
async def discover_service_and_create_proxy(
|
||||||
self, proxy_class: Type[_PROXY_CLASS]
|
self, proxy_class: Type[_PROXY_CLASS]
|
||||||
@@ -1274,6 +1319,7 @@ class Connection(CompositeEventEmitter):
|
|||||||
handle: int
|
handle: int
|
||||||
transport: int
|
transport: int
|
||||||
self_address: Address
|
self_address: Address
|
||||||
|
self_resolvable_address: Optional[Address]
|
||||||
peer_address: Address
|
peer_address: Address
|
||||||
peer_resolvable_address: Optional[Address]
|
peer_resolvable_address: Optional[Address]
|
||||||
peer_le_features: Optional[LeFeatureMask]
|
peer_le_features: Optional[LeFeatureMask]
|
||||||
@@ -1321,6 +1367,7 @@ class Connection(CompositeEventEmitter):
|
|||||||
handle,
|
handle,
|
||||||
transport,
|
transport,
|
||||||
self_address,
|
self_address,
|
||||||
|
self_resolvable_address,
|
||||||
peer_address,
|
peer_address,
|
||||||
peer_resolvable_address,
|
peer_resolvable_address,
|
||||||
role,
|
role,
|
||||||
@@ -1332,6 +1379,7 @@ class Connection(CompositeEventEmitter):
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.self_address = self_address
|
self.self_address = self_address
|
||||||
|
self.self_resolvable_address = self_resolvable_address
|
||||||
self.peer_address = peer_address
|
self.peer_address = peer_address
|
||||||
self.peer_resolvable_address = peer_resolvable_address
|
self.peer_resolvable_address = peer_resolvable_address
|
||||||
self.peer_name = None # Classic only
|
self.peer_name = None # Classic only
|
||||||
@@ -1365,6 +1413,7 @@ class Connection(CompositeEventEmitter):
|
|||||||
None,
|
None,
|
||||||
BT_BR_EDR_TRANSPORT,
|
BT_BR_EDR_TRANSPORT,
|
||||||
device.public_address,
|
device.public_address,
|
||||||
|
None,
|
||||||
peer_address,
|
peer_address,
|
||||||
None,
|
None,
|
||||||
role,
|
role,
|
||||||
@@ -1458,11 +1507,9 @@ class Connection(CompositeEventEmitter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
|
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
|
||||||
except asyncio.TimeoutError:
|
finally:
|
||||||
pass
|
self.remove_listener('disconnection', abort.set_result)
|
||||||
|
self.remove_listener('disconnection_failure', abort.set_exception)
|
||||||
self.remove_listener('disconnection', abort.set_result)
|
|
||||||
self.remove_listener('disconnection_failure', abort.set_exception)
|
|
||||||
|
|
||||||
async def set_data_length(self, tx_octets, tx_time) -> None:
|
async def set_data_length(self, tx_octets, tx_time) -> None:
|
||||||
return await self.device.set_data_length(self, tx_octets, tx_time)
|
return await self.device.set_data_length(self, tx_octets, tx_time)
|
||||||
@@ -1493,6 +1540,11 @@ class Connection(CompositeEventEmitter):
|
|||||||
async def get_phy(self):
|
async def get_phy(self):
|
||||||
return await self.device.get_connection_phy(self)
|
return await self.device.get_connection_phy(self)
|
||||||
|
|
||||||
|
async def transfer_periodic_sync(
|
||||||
|
self, sync_handle: int, service_data: int = 0
|
||||||
|
) -> None:
|
||||||
|
await self.device.transfer_periodic_sync(self, sync_handle, service_data)
|
||||||
|
|
||||||
# [Classic only]
|
# [Classic only]
|
||||||
async def request_remote_name(self):
|
async def request_remote_name(self):
|
||||||
return await self.device.request_remote_name(self)
|
return await self.device.request_remote_name(self)
|
||||||
@@ -1523,7 +1575,9 @@ class Connection(CompositeEventEmitter):
|
|||||||
f'Connection(handle=0x{self.handle:04X}, '
|
f'Connection(handle=0x{self.handle:04X}, '
|
||||||
f'role={self.role_name}, '
|
f'role={self.role_name}, '
|
||||||
f'self_address={self.self_address}, '
|
f'self_address={self.self_address}, '
|
||||||
f'peer_address={self.peer_address})'
|
f'self_resolvable_address={self.self_resolvable_address}, '
|
||||||
|
f'peer_address={self.peer_address}, '
|
||||||
|
f'peer_resolvable_address={self.peer_resolvable_address})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1538,13 +1592,15 @@ class DeviceConfiguration:
|
|||||||
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||||
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||||
le_enabled: bool = True
|
le_enabled: bool = True
|
||||||
# LE host enable 2nd parameter
|
|
||||||
le_simultaneous_enabled: bool = False
|
le_simultaneous_enabled: bool = False
|
||||||
|
le_privacy_enabled: bool = False
|
||||||
|
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
|
||||||
classic_enabled: bool = False
|
classic_enabled: bool = False
|
||||||
classic_sc_enabled: bool = True
|
classic_sc_enabled: bool = True
|
||||||
classic_ssp_enabled: bool = True
|
classic_ssp_enabled: bool = True
|
||||||
classic_smp_enabled: bool = True
|
classic_smp_enabled: bool = True
|
||||||
classic_accept_any: bool = True
|
classic_accept_any: bool = True
|
||||||
|
classic_interlaced_scan_enabled: bool = True
|
||||||
connectable: bool = True
|
connectable: bool = True
|
||||||
discoverable: bool = True
|
discoverable: bool = True
|
||||||
advertising_data: bytes = bytes(
|
advertising_data: bytes = bytes(
|
||||||
@@ -1555,7 +1611,10 @@ class DeviceConfiguration:
|
|||||||
irk: bytes = bytes(16) # This really must be changed for any level of security
|
irk: bytes = bytes(16) # This really must be changed for any level of security
|
||||||
keystore: Optional[str] = None
|
keystore: Optional[str] = None
|
||||||
address_resolution_offload: bool = False
|
address_resolution_offload: bool = False
|
||||||
|
address_generation_offload: bool = False
|
||||||
cis_enabled: bool = False
|
cis_enabled: bool = False
|
||||||
|
identity_address_type: Optional[int] = None
|
||||||
|
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.gatt_services: List[Dict[str, Any]] = []
|
self.gatt_services: List[Dict[str, Any]] = []
|
||||||
@@ -1640,7 +1699,9 @@ def with_connection_from_handle(function):
|
|||||||
@functools.wraps(function)
|
@functools.wraps(function)
|
||||||
def wrapper(self, connection_handle, *args, **kwargs):
|
def wrapper(self, connection_handle, *args, **kwargs):
|
||||||
if (connection := self.lookup_connection(connection_handle)) is None:
|
if (connection := self.lookup_connection(connection_handle)) is None:
|
||||||
raise ValueError(f'no connection for handle: 0x{connection_handle:04x}')
|
raise ObjectLookupError(
|
||||||
|
f'no connection for handle: 0x{connection_handle:04x}'
|
||||||
|
)
|
||||||
return function(self, connection, *args, **kwargs)
|
return function(self, connection, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -1655,7 +1716,7 @@ def with_connection_from_address(function):
|
|||||||
for connection in self.connections.values():
|
for connection in self.connections.values():
|
||||||
if connection.peer_address == address:
|
if connection.peer_address == address:
|
||||||
return function(self, connection, *args, **kwargs)
|
return function(self, connection, *args, **kwargs)
|
||||||
raise ValueError('no connection for address')
|
raise ObjectLookupError('no connection for address')
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -1705,8 +1766,9 @@ device_host_event_handlers: List[str] = []
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Device(CompositeEventEmitter):
|
class Device(CompositeEventEmitter):
|
||||||
# Incomplete list of fields.
|
# Incomplete list of fields.
|
||||||
random_address: Address
|
random_address: Address # Random address that may change with RPA
|
||||||
public_address: Address
|
public_address: Address # Public address (obtained from the controller)
|
||||||
|
static_address: Address # Random address that can be set but does not change
|
||||||
classic_enabled: bool
|
classic_enabled: bool
|
||||||
name: str
|
name: str
|
||||||
class_of_device: int
|
class_of_device: int
|
||||||
@@ -1836,23 +1898,29 @@ class Device(CompositeEventEmitter):
|
|||||||
config = config or DeviceConfiguration()
|
config = config or DeviceConfiguration()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.public_address = Address('00:00:00:00:00:00')
|
|
||||||
self.name = config.name
|
self.name = config.name
|
||||||
|
self.public_address = Address.ANY
|
||||||
self.random_address = config.address
|
self.random_address = config.address
|
||||||
|
self.static_address = config.address
|
||||||
self.class_of_device = config.class_of_device
|
self.class_of_device = config.class_of_device
|
||||||
self.keystore = None
|
self.keystore = None
|
||||||
self.irk = config.irk
|
self.irk = config.irk
|
||||||
self.le_enabled = config.le_enabled
|
self.le_enabled = config.le_enabled
|
||||||
self.classic_enabled = config.classic_enabled
|
|
||||||
self.le_simultaneous_enabled = config.le_simultaneous_enabled
|
self.le_simultaneous_enabled = config.le_simultaneous_enabled
|
||||||
|
self.le_privacy_enabled = config.le_privacy_enabled
|
||||||
|
self.le_rpa_timeout = config.le_rpa_timeout
|
||||||
|
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
|
||||||
|
self.classic_enabled = config.classic_enabled
|
||||||
self.cis_enabled = config.cis_enabled
|
self.cis_enabled = config.cis_enabled
|
||||||
self.classic_sc_enabled = config.classic_sc_enabled
|
self.classic_sc_enabled = config.classic_sc_enabled
|
||||||
self.classic_ssp_enabled = config.classic_ssp_enabled
|
self.classic_ssp_enabled = config.classic_ssp_enabled
|
||||||
self.classic_smp_enabled = config.classic_smp_enabled
|
self.classic_smp_enabled = config.classic_smp_enabled
|
||||||
|
self.classic_interlaced_scan_enabled = config.classic_interlaced_scan_enabled
|
||||||
self.discoverable = config.discoverable
|
self.discoverable = config.discoverable
|
||||||
self.connectable = config.connectable
|
self.connectable = config.connectable
|
||||||
self.classic_accept_any = config.classic_accept_any
|
self.classic_accept_any = config.classic_accept_any
|
||||||
self.address_resolution_offload = config.address_resolution_offload
|
self.address_resolution_offload = config.address_resolution_offload
|
||||||
|
self.address_generation_offload = config.address_generation_offload
|
||||||
|
|
||||||
# Extended advertising.
|
# Extended advertising.
|
||||||
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
|
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
|
||||||
@@ -1908,10 +1976,23 @@ class Device(CompositeEventEmitter):
|
|||||||
if isinstance(address, str):
|
if isinstance(address, str):
|
||||||
address = Address(address)
|
address = Address(address)
|
||||||
self.random_address = address
|
self.random_address = address
|
||||||
|
self.static_address = address
|
||||||
|
|
||||||
# Setup SMP
|
# Setup SMP
|
||||||
self.smp_manager = smp.Manager(
|
self.smp_manager = smp.Manager(
|
||||||
self, pairing_config_factory=lambda connection: PairingConfig()
|
self,
|
||||||
|
pairing_config_factory=lambda connection: pairing.PairingConfig(
|
||||||
|
identity_address_type=(
|
||||||
|
pairing.PairingConfig.AddressType(self.config.identity_address_type)
|
||||||
|
if self.config.identity_address_type
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
delegate=pairing.PairingDelegate(
|
||||||
|
io_capability=pairing.PairingDelegate.IoCapability(
|
||||||
|
self.config.io_capability
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
||||||
@@ -2093,7 +2174,7 @@ class Device(CompositeEventEmitter):
|
|||||||
spec=spec,
|
spec=spec,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unexpected mode {spec}')
|
raise InvalidArgumentError(f'Unexpected mode {spec}')
|
||||||
|
|
||||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||||
@@ -2135,26 +2216,26 @@ class Device(CompositeEventEmitter):
|
|||||||
HCI_Write_LE_Host_Support_Command(
|
HCI_Write_LE_Host_Support_Command(
|
||||||
le_supported_host=int(self.le_enabled),
|
le_supported_host=int(self.le_enabled),
|
||||||
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.le_enabled:
|
if self.le_enabled:
|
||||||
# Set the controller address
|
# Generate a random address if not set.
|
||||||
if self.random_address == Address.ANY_RANDOM:
|
if self.static_address == Address.ANY_RANDOM:
|
||||||
# Try to use an address generated at random by the controller
|
self.static_address = Address.generate_static_address()
|
||||||
if self.host.supports_command(HCI_LE_RAND_COMMAND):
|
|
||||||
# Get 8 random bytes
|
# If LE Privacy is enabled, generate an RPA
|
||||||
response = await self.send_command(
|
if self.le_privacy_enabled:
|
||||||
HCI_LE_Rand_Command(), check_result=True
|
self.random_address = Address.generate_private_address(self.irk)
|
||||||
|
logger.info(f'Initial RPA: {self.random_address}')
|
||||||
|
if self.le_rpa_timeout > 0:
|
||||||
|
# Start a task to periodically generate a new RPA
|
||||||
|
self.le_rpa_periodic_update_task = asyncio.create_task(
|
||||||
|
self._run_rpa_periodic_update()
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
# Ensure the address bytes can be a static random address
|
self.random_address = self.static_address
|
||||||
address_bytes = response.return_parameters.random_number[
|
|
||||||
:5
|
|
||||||
] + bytes([response.return_parameters.random_number[5] | 0xC0])
|
|
||||||
|
|
||||||
# Create a static random address from the random bytes
|
|
||||||
self.random_address = Address(address_bytes)
|
|
||||||
|
|
||||||
if self.random_address != Address.ANY_RANDOM:
|
if self.random_address != Address.ANY_RANDOM:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -2179,7 +2260,8 @@ class Device(CompositeEventEmitter):
|
|||||||
await self.send_command(
|
await self.send_command(
|
||||||
HCI_LE_Set_Address_Resolution_Enable_Command(
|
HCI_LE_Set_Address_Resolution_Enable_Command(
|
||||||
address_resolution_enable=1
|
address_resolution_enable=1
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cis_enabled:
|
if self.cis_enabled:
|
||||||
@@ -2187,7 +2269,8 @@ class Device(CompositeEventEmitter):
|
|||||||
HCI_LE_Set_Host_Feature_Command(
|
HCI_LE_Set_Host_Feature_Command(
|
||||||
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
|
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
|
||||||
bit_value=1,
|
bit_value=1,
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.classic_enabled:
|
if self.classic_enabled:
|
||||||
@@ -2210,6 +2293,21 @@ class Device(CompositeEventEmitter):
|
|||||||
await self.set_connectable(self.connectable)
|
await self.set_connectable(self.connectable)
|
||||||
await self.set_discoverable(self.discoverable)
|
await self.set_discoverable(self.discoverable)
|
||||||
|
|
||||||
|
if self.classic_interlaced_scan_enabled:
|
||||||
|
if self.host.supports_lmp_features(LmpFeatureMask.INTERLACED_PAGE_SCAN):
|
||||||
|
await self.send_command(
|
||||||
|
hci.HCI_Write_Page_Scan_Type_Command(page_scan_type=1),
|
||||||
|
check_result=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.host.supports_lmp_features(
|
||||||
|
LmpFeatureMask.INTERLACED_INQUIRY_SCAN
|
||||||
|
):
|
||||||
|
await self.send_command(
|
||||||
|
hci.HCI_Write_Inquiry_Scan_Type_Command(scan_type=1),
|
||||||
|
check_result=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Done
|
# Done
|
||||||
self.powered_on = True
|
self.powered_on = True
|
||||||
|
|
||||||
@@ -2218,9 +2316,45 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
async def power_off(self) -> None:
|
async def power_off(self) -> None:
|
||||||
if self.powered_on:
|
if self.powered_on:
|
||||||
|
if self.le_rpa_periodic_update_task:
|
||||||
|
self.le_rpa_periodic_update_task.cancel()
|
||||||
|
|
||||||
await self.host.flush()
|
await self.host.flush()
|
||||||
|
|
||||||
self.powered_on = False
|
self.powered_on = False
|
||||||
|
|
||||||
|
async def update_rpa(self) -> bool:
|
||||||
|
"""
|
||||||
|
Try to update the RPA.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the RPA was updated, False if it could not be updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if this is a good time to rotate the address
|
||||||
|
if self.is_advertising or self.is_scanning or self.is_le_connecting:
|
||||||
|
logger.debug('skipping RPA update')
|
||||||
|
return False
|
||||||
|
|
||||||
|
random_address = Address.generate_private_address(self.irk)
|
||||||
|
response = await self.send_command(
|
||||||
|
HCI_LE_Set_Random_Address_Command(random_address=self.random_address)
|
||||||
|
)
|
||||||
|
if response.return_parameters == HCI_SUCCESS:
|
||||||
|
logger.info(f'new RPA: {random_address}')
|
||||||
|
self.random_address = random_address
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f'failed to set RPA: {response.return_parameters}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _run_rpa_periodic_update(self) -> None:
|
||||||
|
"""Update the RPA periodically"""
|
||||||
|
while self.le_rpa_timeout != 0:
|
||||||
|
await asyncio.sleep(self.le_rpa_timeout)
|
||||||
|
if not self.update_rpa():
|
||||||
|
logger.debug("periodic RPA update failed")
|
||||||
|
|
||||||
async def refresh_resolving_list(self) -> None:
|
async def refresh_resolving_list(self) -> None:
|
||||||
assert self.keystore is not None
|
assert self.keystore is not None
|
||||||
|
|
||||||
@@ -2228,7 +2362,7 @@ class Device(CompositeEventEmitter):
|
|||||||
# Create a host-side address resolver
|
# Create a host-side address resolver
|
||||||
self.address_resolver = smp.AddressResolver(resolving_keys)
|
self.address_resolver = smp.AddressResolver(resolving_keys)
|
||||||
|
|
||||||
if self.address_resolution_offload:
|
if self.address_resolution_offload or self.address_generation_offload:
|
||||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
|
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
|
||||||
|
|
||||||
# Add an empty entry for non-directed address generation.
|
# Add an empty entry for non-directed address generation.
|
||||||
@@ -2254,7 +2388,7 @@ class Device(CompositeEventEmitter):
|
|||||||
def supports_le_features(self, feature: LeFeatureMask) -> bool:
|
def supports_le_features(self, feature: LeFeatureMask) -> bool:
|
||||||
return self.host.supports_le_features(feature)
|
return self.host.supports_le_features(feature)
|
||||||
|
|
||||||
def supports_le_phy(self, phy):
|
def supports_le_phy(self, phy: int) -> bool:
|
||||||
if phy == HCI_LE_1M_PHY:
|
if phy == HCI_LE_1M_PHY:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -2263,7 +2397,7 @@ class Device(CompositeEventEmitter):
|
|||||||
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
|
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
|
||||||
}
|
}
|
||||||
if phy not in feature_map:
|
if phy not in feature_map:
|
||||||
raise ValueError('invalid PHY')
|
raise InvalidArgumentError('invalid PHY')
|
||||||
|
|
||||||
return self.supports_le_features(feature_map[phy])
|
return self.supports_le_features(feature_map[phy])
|
||||||
|
|
||||||
@@ -2271,6 +2405,10 @@ class Device(CompositeEventEmitter):
|
|||||||
def supports_le_extended_advertising(self):
|
def supports_le_extended_advertising(self):
|
||||||
return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING)
|
return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_le_periodic_advertising(self):
|
||||||
|
return self.supports_le_features(LeFeatureMask.LE_PERIODIC_ADVERTISING)
|
||||||
|
|
||||||
async def start_advertising(
|
async def start_advertising(
|
||||||
self,
|
self,
|
||||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||||
@@ -2323,7 +2461,7 @@ class Device(CompositeEventEmitter):
|
|||||||
# Decide what peer address to use
|
# Decide what peer address to use
|
||||||
if advertising_type.is_directed:
|
if advertising_type.is_directed:
|
||||||
if target is None:
|
if target is None:
|
||||||
raise ValueError('directed advertising requires a target')
|
raise InvalidArgumentError('directed advertising requires a target')
|
||||||
peer_address = target
|
peer_address = target
|
||||||
else:
|
else:
|
||||||
peer_address = Address.ANY
|
peer_address = Address.ANY
|
||||||
@@ -2430,7 +2568,7 @@ class Device(CompositeEventEmitter):
|
|||||||
and advertising_data
|
and advertising_data
|
||||||
and scan_response_data
|
and scan_response_data
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise InvalidArgumentError(
|
||||||
"Extended advertisements can't have both data and scan \
|
"Extended advertisements can't have both data and scan \
|
||||||
response data"
|
response data"
|
||||||
)
|
)
|
||||||
@@ -2446,7 +2584,9 @@ class Device(CompositeEventEmitter):
|
|||||||
if handle not in self.extended_advertising_sets
|
if handle not in self.extended_advertising_sets
|
||||||
)
|
)
|
||||||
except StopIteration as exc:
|
except StopIteration as exc:
|
||||||
raise RuntimeError("all valid advertising handles already in use") from exc
|
raise OutOfResourcesError(
|
||||||
|
"all valid advertising handles already in use"
|
||||||
|
) from exc
|
||||||
|
|
||||||
# Use the device's random address if a random address is needed but none was
|
# Use the device's random address if a random address is needed but none was
|
||||||
# provided.
|
# provided.
|
||||||
@@ -2545,14 +2685,14 @@ class Device(CompositeEventEmitter):
|
|||||||
) -> None:
|
) -> None:
|
||||||
# Check that the arguments are legal
|
# Check that the arguments are legal
|
||||||
if scan_interval < scan_window:
|
if scan_interval < scan_window:
|
||||||
raise ValueError('scan_interval must be >= scan_window')
|
raise InvalidArgumentError('scan_interval must be >= scan_window')
|
||||||
if (
|
if (
|
||||||
scan_interval < DEVICE_MIN_SCAN_INTERVAL
|
scan_interval < DEVICE_MIN_SCAN_INTERVAL
|
||||||
or scan_interval > DEVICE_MAX_SCAN_INTERVAL
|
or scan_interval > DEVICE_MAX_SCAN_INTERVAL
|
||||||
):
|
):
|
||||||
raise ValueError('scan_interval out of range')
|
raise InvalidArgumentError('scan_interval out of range')
|
||||||
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
|
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
|
||||||
raise ValueError('scan_interval out of range')
|
raise InvalidArgumentError('scan_interval out of range')
|
||||||
|
|
||||||
# Reset the accumulators
|
# Reset the accumulators
|
||||||
self.advertisement_accumulators = {}
|
self.advertisement_accumulators = {}
|
||||||
@@ -2580,7 +2720,7 @@ class Device(CompositeEventEmitter):
|
|||||||
scanning_phy_count += 1
|
scanning_phy_count += 1
|
||||||
|
|
||||||
if scanning_phy_count == 0:
|
if scanning_phy_count == 0:
|
||||||
raise ValueError('at least one scanning PHY must be enabled')
|
raise InvalidArgumentError('at least one scanning PHY must be enabled')
|
||||||
|
|
||||||
await self.send_command(
|
await self.send_command(
|
||||||
HCI_LE_Set_Extended_Scan_Parameters_Command(
|
HCI_LE_Set_Extended_Scan_Parameters_Command(
|
||||||
@@ -2671,6 +2811,10 @@ class Device(CompositeEventEmitter):
|
|||||||
sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT,
|
sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT,
|
||||||
filter_duplicates: bool = False,
|
filter_duplicates: bool = False,
|
||||||
) -> PeriodicAdvertisingSync:
|
) -> PeriodicAdvertisingSync:
|
||||||
|
# Check that the controller supports the feature.
|
||||||
|
if not self.supports_le_periodic_advertising:
|
||||||
|
raise NotSupportedError()
|
||||||
|
|
||||||
# Check that there isn't already an equivalent entry
|
# Check that there isn't already an equivalent entry
|
||||||
if any(
|
if any(
|
||||||
sync.advertiser_address == advertiser_address and sync.sid == sid
|
sync.advertiser_address == advertiser_address and sync.sid == sid
|
||||||
@@ -2868,23 +3012,52 @@ class Device(CompositeEventEmitter):
|
|||||||
] = None,
|
] = None,
|
||||||
own_address_type: int = OwnAddressType.RANDOM,
|
own_address_type: int = OwnAddressType.RANDOM,
|
||||||
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||||
|
always_resolve: bool = False,
|
||||||
) -> Connection:
|
) -> Connection:
|
||||||
'''
|
'''
|
||||||
Request a connection to a peer.
|
Request a connection to a peer.
|
||||||
When transport is BLE, this method cannot be called if there is already a
|
|
||||||
|
When the transport is BLE, this method cannot be called if there is already a
|
||||||
pending connection.
|
pending connection.
|
||||||
|
|
||||||
connection_parameters_preferences: (BLE only, ignored for BR/EDR)
|
Args:
|
||||||
* None: use the 1M PHY with default parameters
|
peer_address:
|
||||||
* map: each entry has a PHY as key and a ConnectionParametersPreferences
|
Address or name of the device to connect to.
|
||||||
object as value
|
If a string is passed:
|
||||||
|
If the string is an address followed by a `@` suffix, the `always_resolve`
|
||||||
|
argument is implicitly set to True, so the connection is made to the
|
||||||
|
address after resolution.
|
||||||
|
If the string is any other address, the connection is made to that
|
||||||
|
address (with or without address resolution, depending on the
|
||||||
|
`always_resolve` argument).
|
||||||
|
For any other string, a scan for devices using that string as their name
|
||||||
|
is initiated, and a connection to the first matching device's address
|
||||||
|
is made. In that case, `always_resolve` is ignored.
|
||||||
|
|
||||||
own_address_type: (BLE only)
|
connection_parameters_preferences:
|
||||||
|
(BLE only, ignored for BR/EDR)
|
||||||
|
* None: use the 1M PHY with default parameters
|
||||||
|
* map: each entry has a PHY as key and a ConnectionParametersPreferences
|
||||||
|
object as value
|
||||||
|
|
||||||
|
own_address_type:
|
||||||
|
(BLE only, ignored for BR/EDR)
|
||||||
|
OwnAddressType.RANDOM to use this device's random address, or
|
||||||
|
OwnAddressType.PUBLIC to use this device's public address.
|
||||||
|
|
||||||
|
timeout:
|
||||||
|
Maximum time to wait for a connection to be established, in seconds.
|
||||||
|
Pass None for an unlimited time.
|
||||||
|
|
||||||
|
always_resolve:
|
||||||
|
(BLE only, ignored for BR/EDR)
|
||||||
|
If True, always initiate a scan, resolving addresses, and connect to the
|
||||||
|
address that resolves to `peer_address`.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# Check parameters
|
# Check parameters
|
||||||
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
|
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
|
||||||
raise ValueError('invalid transport')
|
raise InvalidArgumentError('invalid transport')
|
||||||
|
|
||||||
# Adjust the transport automatically if we need to
|
# Adjust the transport automatically if we need to
|
||||||
if transport == BT_LE_TRANSPORT and not self.le_enabled:
|
if transport == BT_LE_TRANSPORT and not self.le_enabled:
|
||||||
@@ -2898,11 +3071,19 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
if isinstance(peer_address, str):
|
if isinstance(peer_address, str):
|
||||||
try:
|
try:
|
||||||
peer_address = Address.from_string_for_transport(
|
if transport == BT_LE_TRANSPORT and peer_address.endswith('@'):
|
||||||
peer_address, transport
|
peer_address = Address.from_string_for_transport(
|
||||||
)
|
peer_address[:-1], transport
|
||||||
except ValueError:
|
)
|
||||||
|
always_resolve = True
|
||||||
|
logger.debug('forcing address resolution')
|
||||||
|
else:
|
||||||
|
peer_address = Address.from_string_for_transport(
|
||||||
|
peer_address, transport
|
||||||
|
)
|
||||||
|
except (InvalidArgumentError, ValueError):
|
||||||
# If the address is not parsable, assume it is a name instead
|
# If the address is not parsable, assume it is a name instead
|
||||||
|
always_resolve = False
|
||||||
logger.debug('looking for peer by name')
|
logger.debug('looking for peer by name')
|
||||||
peer_address = await self.find_peer_by_name(
|
peer_address = await self.find_peer_by_name(
|
||||||
peer_address, transport
|
peer_address, transport
|
||||||
@@ -2913,10 +3094,16 @@ class Device(CompositeEventEmitter):
|
|||||||
transport == BT_BR_EDR_TRANSPORT
|
transport == BT_BR_EDR_TRANSPORT
|
||||||
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
|
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
|
||||||
):
|
):
|
||||||
raise ValueError('BR/EDR addresses must be PUBLIC')
|
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
|
||||||
|
|
||||||
assert isinstance(peer_address, Address)
|
assert isinstance(peer_address, Address)
|
||||||
|
|
||||||
|
if transport == BT_LE_TRANSPORT and always_resolve:
|
||||||
|
logger.debug('resolving address')
|
||||||
|
peer_address = await self.find_peer_by_identity_address(
|
||||||
|
peer_address
|
||||||
|
) # TODO: timeout
|
||||||
|
|
||||||
def on_connection(connection):
|
def on_connection(connection):
|
||||||
if transport == BT_LE_TRANSPORT or (
|
if transport == BT_LE_TRANSPORT or (
|
||||||
# match BR/EDR connection event against peer address
|
# match BR/EDR connection event against peer address
|
||||||
@@ -2964,7 +3151,7 @@ class Device(CompositeEventEmitter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if not phys:
|
if not phys:
|
||||||
raise ValueError('at least one supported PHY needed')
|
raise InvalidArgumentError('at least one supported PHY needed')
|
||||||
|
|
||||||
phy_count = len(phys)
|
phy_count = len(phys)
|
||||||
initiating_phys = phy_list_to_bits(phys)
|
initiating_phys = phy_list_to_bits(phys)
|
||||||
@@ -3036,7 +3223,7 @@ class Device(CompositeEventEmitter):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if HCI_LE_1M_PHY not in connection_parameters_preferences:
|
if HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||||
raise ValueError('1M PHY preferences required')
|
raise InvalidArgumentError('1M PHY preferences required')
|
||||||
|
|
||||||
prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
|
prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
|
||||||
result = await self.send_command(
|
result = await self.send_command(
|
||||||
@@ -3136,7 +3323,7 @@ class Device(CompositeEventEmitter):
|
|||||||
if isinstance(peer_address, str):
|
if isinstance(peer_address, str):
|
||||||
try:
|
try:
|
||||||
peer_address = Address(peer_address)
|
peer_address = Address(peer_address)
|
||||||
except ValueError:
|
except InvalidArgumentError:
|
||||||
# If the address is not parsable, assume it is a name instead
|
# If the address is not parsable, assume it is a name instead
|
||||||
logger.debug('looking for peer by name')
|
logger.debug('looking for peer by name')
|
||||||
peer_address = await self.find_peer_by_name(
|
peer_address = await self.find_peer_by_name(
|
||||||
@@ -3146,7 +3333,7 @@ class Device(CompositeEventEmitter):
|
|||||||
assert isinstance(peer_address, Address)
|
assert isinstance(peer_address, Address)
|
||||||
|
|
||||||
if peer_address == Address.NIL:
|
if peer_address == Address.NIL:
|
||||||
raise ValueError('accept on nil address')
|
raise InvalidArgumentError('accept on nil address')
|
||||||
|
|
||||||
# Create a future so that we can wait for the request
|
# Create a future so that we can wait for the request
|
||||||
pending_request_fut = asyncio.get_running_loop().create_future()
|
pending_request_fut = asyncio.get_running_loop().create_future()
|
||||||
@@ -3259,7 +3446,7 @@ class Device(CompositeEventEmitter):
|
|||||||
if isinstance(peer_address, str):
|
if isinstance(peer_address, str):
|
||||||
try:
|
try:
|
||||||
peer_address = Address(peer_address)
|
peer_address = Address(peer_address)
|
||||||
except ValueError:
|
except InvalidArgumentError:
|
||||||
# If the address is not parsable, assume it is a name instead
|
# If the address is not parsable, assume it is a name instead
|
||||||
logger.debug('looking for peer by name')
|
logger.debug('looking for peer by name')
|
||||||
peer_address = await self.find_peer_by_name(
|
peer_address = await self.find_peer_by_name(
|
||||||
@@ -3302,10 +3489,10 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
|
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
|
||||||
if tx_octets < 0x001B or tx_octets > 0x00FB:
|
if tx_octets < 0x001B or tx_octets > 0x00FB:
|
||||||
raise ValueError('tx_octets must be between 0x001B and 0x00FB')
|
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
|
||||||
|
|
||||||
if tx_time < 0x0148 or tx_time > 0x4290:
|
if tx_time < 0x0148 or tx_time > 0x4290:
|
||||||
raise ValueError('tx_time must be between 0x0148 and 0x4290')
|
raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')
|
||||||
|
|
||||||
return await self.send_command(
|
return await self.send_command(
|
||||||
HCI_LE_Set_Data_Length_Command(
|
HCI_LE_Set_Data_Length_Command(
|
||||||
@@ -3418,15 +3605,26 @@ class Device(CompositeEventEmitter):
|
|||||||
check_result=True,
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def transfer_periodic_sync(
|
||||||
|
self, connection: Connection, sync_handle: int, service_data: int = 0
|
||||||
|
) -> None:
|
||||||
|
return await self.send_command(
|
||||||
|
HCI_LE_Periodic_Advertising_Sync_Transfer_Command(
|
||||||
|
connection_handle=connection.handle,
|
||||||
|
service_data=service_data,
|
||||||
|
sync_handle=sync_handle,
|
||||||
|
),
|
||||||
|
check_result=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT):
|
async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT):
|
||||||
"""
|
"""
|
||||||
Scan for a peer with a give name and return its address and transport
|
Scan for a peer with a given name and return its address.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create a future to wait for an address to be found
|
# Create a future to wait for an address to be found
|
||||||
peer_address = asyncio.get_running_loop().create_future()
|
peer_address = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
# Scan/inquire with event handlers to handle scan/inquiry results
|
|
||||||
def on_peer_found(address, ad_data):
|
def on_peer_found(address, ad_data):
|
||||||
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
|
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
|
||||||
if local_name is None:
|
if local_name is None:
|
||||||
@@ -3435,13 +3633,13 @@ class Device(CompositeEventEmitter):
|
|||||||
if local_name.decode('utf-8') == name:
|
if local_name.decode('utf-8') == name:
|
||||||
peer_address.set_result(address)
|
peer_address.set_result(address)
|
||||||
|
|
||||||
handler = None
|
listener = None
|
||||||
was_scanning = self.scanning
|
was_scanning = self.scanning
|
||||||
was_discovering = self.discovering
|
was_discovering = self.discovering
|
||||||
try:
|
try:
|
||||||
if transport == BT_LE_TRANSPORT:
|
if transport == BT_LE_TRANSPORT:
|
||||||
event_name = 'advertisement'
|
event_name = 'advertisement'
|
||||||
handler = self.on(
|
listener = self.on(
|
||||||
event_name,
|
event_name,
|
||||||
lambda advertisement: on_peer_found(
|
lambda advertisement: on_peer_found(
|
||||||
advertisement.address, advertisement.data
|
advertisement.address, advertisement.data
|
||||||
@@ -3453,7 +3651,7 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
elif transport == BT_BR_EDR_TRANSPORT:
|
elif transport == BT_BR_EDR_TRANSPORT:
|
||||||
event_name = 'inquiry_result'
|
event_name = 'inquiry_result'
|
||||||
handler = self.on(
|
listener = self.on(
|
||||||
event_name,
|
event_name,
|
||||||
lambda address, class_of_device, eir_data, rssi: on_peer_found(
|
lambda address, class_of_device, eir_data, rssi: on_peer_found(
|
||||||
address, eir_data
|
address, eir_data
|
||||||
@@ -3467,21 +3665,67 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
return await self.abort_on('flush', peer_address)
|
return await self.abort_on('flush', peer_address)
|
||||||
finally:
|
finally:
|
||||||
if handler is not None:
|
if listener is not None:
|
||||||
self.remove_listener(event_name, handler)
|
self.remove_listener(event_name, listener)
|
||||||
|
|
||||||
if transport == BT_LE_TRANSPORT and not was_scanning:
|
if transport == BT_LE_TRANSPORT and not was_scanning:
|
||||||
await self.stop_scanning()
|
await self.stop_scanning()
|
||||||
elif transport == BT_BR_EDR_TRANSPORT and not was_discovering:
|
elif transport == BT_BR_EDR_TRANSPORT and not was_discovering:
|
||||||
await self.stop_discovery()
|
await self.stop_discovery()
|
||||||
|
|
||||||
|
async def find_peer_by_identity_address(self, identity_address: Address) -> Address:
|
||||||
|
"""
|
||||||
|
Scan for a peer with a resolvable address that can be resolved to a given
|
||||||
|
identity address.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a future to wait for an address to be found
|
||||||
|
peer_address = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
|
def on_peer_found(address, _):
|
||||||
|
if address == identity_address:
|
||||||
|
if not peer_address.done():
|
||||||
|
logger.debug(f'*** Matching public address found for {address}')
|
||||||
|
peer_address.set_result(address)
|
||||||
|
return
|
||||||
|
|
||||||
|
if address.is_resolvable:
|
||||||
|
resolved_address = self.address_resolver.resolve(address)
|
||||||
|
if resolved_address == identity_address:
|
||||||
|
if not peer_address.done():
|
||||||
|
logger.debug(f'*** Matching identity found for {address}')
|
||||||
|
peer_address.set_result(address)
|
||||||
|
return
|
||||||
|
|
||||||
|
was_scanning = self.scanning
|
||||||
|
event_name = 'advertisement'
|
||||||
|
listener = None
|
||||||
|
try:
|
||||||
|
listener = self.on(
|
||||||
|
event_name,
|
||||||
|
lambda advertisement: on_peer_found(
|
||||||
|
advertisement.address, advertisement.data
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.scanning:
|
||||||
|
await self.start_scanning(filter_duplicates=True)
|
||||||
|
|
||||||
|
return await self.abort_on('flush', peer_address)
|
||||||
|
finally:
|
||||||
|
if listener is not None:
|
||||||
|
self.remove_listener(event_name, listener)
|
||||||
|
|
||||||
|
if not was_scanning:
|
||||||
|
await self.stop_scanning()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
|
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
|
||||||
return self.smp_manager.pairing_config_factory
|
return self.smp_manager.pairing_config_factory
|
||||||
|
|
||||||
@pairing_config_factory.setter
|
@pairing_config_factory.setter
|
||||||
def pairing_config_factory(
|
def pairing_config_factory(
|
||||||
self, pairing_config_factory: Callable[[Connection], PairingConfig]
|
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.smp_manager.pairing_config_factory = pairing_config_factory
|
self.smp_manager.pairing_config_factory = pairing_config_factory
|
||||||
|
|
||||||
@@ -3580,7 +3824,7 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
async def encrypt(self, connection, enable=True):
|
async def encrypt(self, connection, enable=True):
|
||||||
if not enable and connection.transport == BT_LE_TRANSPORT:
|
if not enable and connection.transport == BT_LE_TRANSPORT:
|
||||||
raise ValueError('`enable` parameter is classic only.')
|
raise InvalidArgumentError('`enable` parameter is classic only.')
|
||||||
|
|
||||||
# Set up event handlers
|
# Set up event handlers
|
||||||
pending_encryption = asyncio.get_running_loop().create_future()
|
pending_encryption = asyncio.get_running_loop().create_future()
|
||||||
@@ -3599,11 +3843,12 @@ class Device(CompositeEventEmitter):
|
|||||||
if connection.transport == BT_LE_TRANSPORT:
|
if connection.transport == BT_LE_TRANSPORT:
|
||||||
# Look for a key in the key store
|
# Look for a key in the key store
|
||||||
if self.keystore is None:
|
if self.keystore is None:
|
||||||
raise RuntimeError('no key store')
|
raise InvalidOperationError('no key store')
|
||||||
|
|
||||||
|
logger.debug(f'Looking up key for {connection.peer_address}')
|
||||||
keys = await self.keystore.get(str(connection.peer_address))
|
keys = await self.keystore.get(str(connection.peer_address))
|
||||||
if keys is None:
|
if keys is None:
|
||||||
raise RuntimeError('keys not found in key store')
|
raise InvalidOperationError('keys not found in key store')
|
||||||
|
|
||||||
if keys.ltk is not None:
|
if keys.ltk is not None:
|
||||||
ltk = keys.ltk.value
|
ltk = keys.ltk.value
|
||||||
@@ -3614,7 +3859,7 @@ class Device(CompositeEventEmitter):
|
|||||||
rand = keys.ltk_central.rand
|
rand = keys.ltk_central.rand
|
||||||
ediv = keys.ltk_central.ediv
|
ediv = keys.ltk_central.ediv
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('no LTK found for peer')
|
raise InvalidOperationError('no LTK found for peer')
|
||||||
|
|
||||||
if connection.role != HCI_CENTRAL_ROLE:
|
if connection.role != HCI_CENTRAL_ROLE:
|
||||||
raise InvalidStateError('only centrals can start encryption')
|
raise InvalidStateError('only centrals can start encryption')
|
||||||
@@ -3889,7 +4134,7 @@ class Device(CompositeEventEmitter):
|
|||||||
return cis_link
|
return cis_link
|
||||||
|
|
||||||
# Mypy believes this is reachable when context is an ExitStack.
|
# Mypy believes this is reachable when context is an ExitStack.
|
||||||
raise InvalidStateError('Unreachable')
|
raise UnreachableError()
|
||||||
|
|
||||||
# [LE only]
|
# [LE only]
|
||||||
@experimental('Only for testing.')
|
@experimental('Only for testing.')
|
||||||
@@ -4036,6 +4281,12 @@ class Device(CompositeEventEmitter):
|
|||||||
else self.public_address
|
else self.public_address
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if advertising_set.advertising_parameters.own_address_type in (
|
||||||
|
OwnAddressType.RANDOM,
|
||||||
|
OwnAddressType.PUBLIC,
|
||||||
|
):
|
||||||
|
connection.self_resolvable_address = None
|
||||||
|
|
||||||
# Setup auto-restart of the advertising set if needed.
|
# Setup auto-restart of the advertising set if needed.
|
||||||
if advertising_set.auto_restart:
|
if advertising_set.auto_restart:
|
||||||
connection.once(
|
connection.once(
|
||||||
@@ -4071,12 +4322,23 @@ class Device(CompositeEventEmitter):
|
|||||||
@host_event_handler
|
@host_event_handler
|
||||||
def on_connection(
|
def on_connection(
|
||||||
self,
|
self,
|
||||||
connection_handle,
|
connection_handle: int,
|
||||||
transport,
|
transport: int,
|
||||||
peer_address,
|
peer_address: Address,
|
||||||
role,
|
self_resolvable_address: Optional[Address],
|
||||||
connection_parameters,
|
peer_resolvable_address: Optional[Address],
|
||||||
):
|
role: int,
|
||||||
|
connection_parameters: ConnectionParameters,
|
||||||
|
) -> None:
|
||||||
|
# Convert all-zeros addresses into None.
|
||||||
|
if self_resolvable_address == Address.ANY_RANDOM:
|
||||||
|
self_resolvable_address = None
|
||||||
|
if (
|
||||||
|
peer_resolvable_address == Address.ANY_RANDOM
|
||||||
|
or not peer_address.is_resolved
|
||||||
|
):
|
||||||
|
peer_resolvable_address = None
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'*** Connection: [0x{connection_handle:04X}] '
|
f'*** Connection: [0x{connection_handle:04X}] '
|
||||||
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
|
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
|
||||||
@@ -4097,17 +4359,18 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Resolve the peer address if we can
|
if peer_resolvable_address is None:
|
||||||
peer_resolvable_address = None
|
# Resolve the peer address if we can
|
||||||
if self.address_resolver:
|
if self.address_resolver:
|
||||||
if peer_address.is_resolvable:
|
if peer_address.is_resolvable:
|
||||||
resolved_address = self.address_resolver.resolve(peer_address)
|
resolved_address = self.address_resolver.resolve(peer_address)
|
||||||
if resolved_address is not None:
|
if resolved_address is not None:
|
||||||
logger.debug(f'*** Address resolved as {resolved_address}')
|
logger.debug(f'*** Address resolved as {resolved_address}')
|
||||||
peer_resolvable_address = peer_address
|
peer_resolvable_address = peer_address
|
||||||
peer_address = resolved_address
|
peer_address = resolved_address
|
||||||
|
|
||||||
self_address = None
|
self_address = None
|
||||||
|
own_address_type: Optional[int] = None
|
||||||
if role == HCI_CENTRAL_ROLE:
|
if role == HCI_CENTRAL_ROLE:
|
||||||
own_address_type = self.connect_own_address_type
|
own_address_type = self.connect_own_address_type
|
||||||
assert own_address_type is not None
|
assert own_address_type is not None
|
||||||
@@ -4136,12 +4399,18 @@ class Device(CompositeEventEmitter):
|
|||||||
else self.random_address
|
else self.random_address
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Some controllers may return local resolvable address even not using address
|
||||||
|
# generation offloading. Ignore the value to prevent SMP failure.
|
||||||
|
if own_address_type in (OwnAddressType.RANDOM, OwnAddressType.PUBLIC):
|
||||||
|
self_resolvable_address = None
|
||||||
|
|
||||||
# Create a connection.
|
# Create a connection.
|
||||||
connection = Connection(
|
connection = Connection(
|
||||||
self,
|
self,
|
||||||
connection_handle,
|
connection_handle,
|
||||||
transport,
|
transport,
|
||||||
self_address,
|
self_address,
|
||||||
|
self_resolvable_address,
|
||||||
peer_address,
|
peer_address,
|
||||||
peer_resolvable_address,
|
peer_resolvable_address,
|
||||||
role,
|
role,
|
||||||
@@ -4152,9 +4421,10 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
|
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
|
||||||
if self.legacy_advertiser.auto_restart:
|
if self.legacy_advertiser.auto_restart:
|
||||||
|
advertiser = self.legacy_advertiser
|
||||||
connection.once(
|
connection.once(
|
||||||
'disconnection',
|
'disconnection',
|
||||||
lambda _: self.abort_on('flush', self.legacy_advertiser.start()),
|
lambda _: self.abort_on('flush', advertiser.start()),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.legacy_advertiser = None
|
self.legacy_advertiser = None
|
||||||
@@ -4377,7 +4647,7 @@ class Device(CompositeEventEmitter):
|
|||||||
return await pairing_config.delegate.confirm(auto=True)
|
return await pairing_config.delegate.confirm(auto=True)
|
||||||
|
|
||||||
async def na() -> bool:
|
async def na() -> bool:
|
||||||
assert False, "N/A: unreachable"
|
raise UnreachableError()
|
||||||
|
|
||||||
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
|
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
|
||||||
methods = {
|
methods = {
|
||||||
@@ -4838,5 +5108,6 @@ class Device(CompositeEventEmitter):
|
|||||||
return (
|
return (
|
||||||
f'Device(name="{self.name}", '
|
f'Device(name="{self.name}", '
|
||||||
f'random_address="{self.random_address}", '
|
f'random_address="{self.random_address}", '
|
||||||
f'public_address="{self.public_address}")'
|
f'public_address="{self.public_address}", '
|
||||||
|
f'static_address="{self.static_address}")'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from typing import Tuple
|
|||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
from bumble.hci import (
|
from bumble.hci import (
|
||||||
hci_vendor_command_op_code,
|
hci_vendor_command_op_code,
|
||||||
STATUS_SPEC,
|
STATUS_SPEC,
|
||||||
@@ -49,6 +50,10 @@ from bumble.drivers import common
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RtkFirmwareError(core.BaseBumbleError):
|
||||||
|
"""Error raised when RTK firmware initialization fails."""
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -208,15 +213,15 @@ class Firmware:
|
|||||||
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
||||||
|
|
||||||
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
||||||
raise ValueError("Firmware does not start with epatch signature")
|
raise RtkFirmwareError("Firmware does not start with epatch signature")
|
||||||
|
|
||||||
if not firmware.endswith(extension_sig):
|
if not firmware.endswith(extension_sig):
|
||||||
raise ValueError("Firmware does not end with extension sig")
|
raise RtkFirmwareError("Firmware does not end with extension sig")
|
||||||
|
|
||||||
# The firmware should start with a 14 byte header.
|
# The firmware should start with a 14 byte header.
|
||||||
epatch_header_size = 14
|
epatch_header_size = 14
|
||||||
if len(firmware) < epatch_header_size:
|
if len(firmware) < epatch_header_size:
|
||||||
raise ValueError("Firmware too short")
|
raise RtkFirmwareError("Firmware too short")
|
||||||
|
|
||||||
# Look for the "project ID", starting from the end.
|
# Look for the "project ID", starting from the end.
|
||||||
offset = len(firmware) - len(extension_sig)
|
offset = len(firmware) - len(extension_sig)
|
||||||
@@ -230,7 +235,7 @@ class Firmware:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if length == 0:
|
if length == 0:
|
||||||
raise ValueError("Invalid 0-length instruction")
|
raise RtkFirmwareError("Invalid 0-length instruction")
|
||||||
|
|
||||||
if opcode == 0 and length == 1:
|
if opcode == 0 and length == 1:
|
||||||
project_id = firmware[offset - 1]
|
project_id = firmware[offset - 1]
|
||||||
@@ -239,7 +244,7 @@ class Firmware:
|
|||||||
offset -= length
|
offset -= length
|
||||||
|
|
||||||
if project_id < 0:
|
if project_id < 0:
|
||||||
raise ValueError("Project ID not found")
|
raise RtkFirmwareError("Project ID not found")
|
||||||
|
|
||||||
self.project_id = project_id
|
self.project_id = project_id
|
||||||
|
|
||||||
@@ -252,7 +257,7 @@ class Firmware:
|
|||||||
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
||||||
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
||||||
if epatch_header_size + 8 * num_patches > len(firmware):
|
if epatch_header_size + 8 * num_patches > len(firmware):
|
||||||
raise ValueError("Firmware too short")
|
raise RtkFirmwareError("Firmware too short")
|
||||||
chip_id_table_offset = epatch_header_size
|
chip_id_table_offset = epatch_header_size
|
||||||
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
||||||
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
||||||
@@ -266,7 +271,7 @@ class Firmware:
|
|||||||
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
||||||
)
|
)
|
||||||
if patch_offset + patch_length > len(firmware):
|
if patch_offset + patch_length > len(firmware):
|
||||||
raise ValueError("Firmware too short")
|
raise RtkFirmwareError("Firmware too short")
|
||||||
|
|
||||||
# Get the SVN version for the patch
|
# Get the SVN version for the patch
|
||||||
(svn_version,) = struct.unpack_from(
|
(svn_version,) = struct.unpack_from(
|
||||||
@@ -645,7 +650,7 @@ class Driver(common.Driver):
|
|||||||
):
|
):
|
||||||
return await self.download_for_rtl8723b()
|
return await self.download_for_rtl8723b()
|
||||||
|
|
||||||
raise ValueError("ROM not supported")
|
raise RtkFirmwareError("ROM not supported")
|
||||||
|
|
||||||
async def init_controller(self):
|
async def init_controller(self):
|
||||||
await self.download_firmware()
|
await self.download_firmware()
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.core import UUID
|
from bumble.core import BaseBumbleError, UUID
|
||||||
from bumble.att import Attribute, AttributeValue
|
from bumble.att import Attribute, AttributeValue
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -320,6 +320,11 @@ def show_services(services: Iterable[Service]) -> None:
|
|||||||
print(color(' ' + str(descriptor), 'green'))
|
print(color(' ' + str(descriptor), 'green'))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class InvalidServiceError(BaseBumbleError):
|
||||||
|
"""The service is not compliant with the spec/profile"""
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Service(Attribute):
|
class Service(Attribute):
|
||||||
'''
|
'''
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class ProfileServiceProxy:
|
|||||||
SERVICE_CLASS: Type[TemplateService]
|
SERVICE_CLASS: Type[TemplateService]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_client(cls, client: Client) -> ProfileServiceProxy:
|
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
|
||||||
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
||||||
|
|
||||||
|
|
||||||
@@ -283,6 +283,8 @@ class Client:
|
|||||||
self.services = []
|
self.services = []
|
||||||
self.cached_values = {}
|
self.cached_values = {}
|
||||||
|
|
||||||
|
connection.on('disconnection', self.on_disconnection)
|
||||||
|
|
||||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||||
|
|
||||||
@@ -331,9 +333,9 @@ class Client:
|
|||||||
async def request_mtu(self, mtu: int) -> int:
|
async def request_mtu(self, mtu: int) -> int:
|
||||||
# Check the range
|
# Check the range
|
||||||
if mtu < ATT_DEFAULT_MTU:
|
if mtu < ATT_DEFAULT_MTU:
|
||||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||||
if mtu > 0xFFFF:
|
if mtu > 0xFFFF:
|
||||||
raise ValueError('MTU must be <= 0xFFFF')
|
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
|
||||||
|
|
||||||
# We can only send one request per connection
|
# We can only send one request per connection
|
||||||
if self.mtu_exchange_done:
|
if self.mtu_exchange_done:
|
||||||
@@ -405,7 +407,7 @@ class Client:
|
|||||||
if not already_known:
|
if not already_known:
|
||||||
self.services.append(service)
|
self.services.append(service)
|
||||||
|
|
||||||
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
|
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||||
'''
|
'''
|
||||||
@@ -1072,6 +1074,10 @@ class Client:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_disconnection(self, _) -> None:
|
||||||
|
if self.pending_response and not self.pending_response.done():
|
||||||
|
self.pending_response.cancel()
|
||||||
|
|
||||||
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ from bumble.core import (
|
|||||||
BT_BR_EDR_TRANSPORT,
|
BT_BR_EDR_TRANSPORT,
|
||||||
AdvertisingData,
|
AdvertisingData,
|
||||||
DeviceClass,
|
DeviceClass,
|
||||||
|
InvalidArgumentError,
|
||||||
|
InvalidPacketError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
bit_flags_to_strings,
|
bit_flags_to_strings,
|
||||||
name_or_number,
|
name_or_number,
|
||||||
@@ -92,14 +94,14 @@ def map_class_of_device(class_of_device):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def phy_list_to_bits(phys):
|
def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int:
|
||||||
if phys is None:
|
if phys is None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
phy_bits = 0
|
phy_bits = 0
|
||||||
for phy in phys:
|
for phy in phys:
|
||||||
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
|
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
|
||||||
raise ValueError('invalid PHY')
|
raise InvalidArgumentError('invalid PHY')
|
||||||
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
|
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
|
||||||
return phy_bits
|
return phy_bits
|
||||||
|
|
||||||
@@ -1553,7 +1555,7 @@ class HCI_Object:
|
|||||||
new_offset, field_value = field_type(data, offset)
|
new_offset, field_value = field_type(data, offset)
|
||||||
return (field_value, new_offset - offset)
|
return (field_value, new_offset - offset)
|
||||||
|
|
||||||
raise ValueError(f'unknown field type {field_type}')
|
raise InvalidArgumentError(f'unknown field type {field_type}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dict_from_bytes(data, offset, fields):
|
def dict_from_bytes(data, offset, fields):
|
||||||
@@ -1622,7 +1624,7 @@ class HCI_Object:
|
|||||||
if 0 <= field_value <= 255:
|
if 0 <= field_value <= 255:
|
||||||
field_bytes = bytes([field_value])
|
field_bytes = bytes([field_value])
|
||||||
else:
|
else:
|
||||||
raise ValueError('value too large for *-typed field')
|
raise InvalidArgumentError('value too large for *-typed field')
|
||||||
else:
|
else:
|
||||||
field_bytes = bytes(field_value)
|
field_bytes = bytes(field_value)
|
||||||
elif field_type == 'v':
|
elif field_type == 'v':
|
||||||
@@ -1641,7 +1643,9 @@ class HCI_Object:
|
|||||||
elif len(field_bytes) > field_type:
|
elif len(field_bytes) > field_type:
|
||||||
field_bytes = field_bytes[:field_type]
|
field_bytes = field_bytes[:field_type]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"don't know how to serialize type {type(field_value)}")
|
raise InvalidArgumentError(
|
||||||
|
f"don't know how to serialize type {type(field_value)}"
|
||||||
|
)
|
||||||
|
|
||||||
return field_bytes
|
return field_bytes
|
||||||
|
|
||||||
@@ -1835,6 +1839,12 @@ class Address:
|
|||||||
data, offset, Address.PUBLIC_DEVICE_ADDRESS
|
data, offset, Address.PUBLIC_DEVICE_ADDRESS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_random_address(data, offset):
|
||||||
|
return Address.parse_address_with_type(
|
||||||
|
data, offset, Address.RANDOM_DEVICE_ADDRESS
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_address_with_type(data, offset, address_type):
|
def parse_address_with_type(data, offset, address_type):
|
||||||
return offset + 6, Address(data[offset : offset + 6], address_type)
|
return offset + 6, Address(data[offset : offset + 6], address_type)
|
||||||
@@ -1905,7 +1915,7 @@ class Address:
|
|||||||
self.address_bytes = bytes(reversed(bytes.fromhex(address)))
|
self.address_bytes = bytes(reversed(bytes.fromhex(address)))
|
||||||
|
|
||||||
if len(self.address_bytes) != 6:
|
if len(self.address_bytes) != 6:
|
||||||
raise ValueError('invalid address length')
|
raise InvalidArgumentError('invalid address length')
|
||||||
|
|
||||||
self.address_type = address_type
|
self.address_type = address_type
|
||||||
|
|
||||||
@@ -1961,7 +1971,8 @@ class Address:
|
|||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (
|
return (
|
||||||
self.address_bytes == other.address_bytes
|
isinstance(other, Address)
|
||||||
|
and self.address_bytes == other.address_bytes
|
||||||
and self.is_public == other.is_public
|
and self.is_public == other.is_public
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2108,7 +2119,7 @@ class HCI_Command(HCI_Packet):
|
|||||||
op_code, length = struct.unpack_from('<HB', packet, 1)
|
op_code, length = struct.unpack_from('<HB', packet, 1)
|
||||||
parameters = packet[4:]
|
parameters = packet[4:]
|
||||||
if len(parameters) != length:
|
if len(parameters) != length:
|
||||||
raise ValueError('invalid packet length')
|
raise InvalidPacketError('invalid packet length')
|
||||||
|
|
||||||
# Look for a registered class
|
# Look for a registered class
|
||||||
cls = HCI_Command.command_classes.get(op_code)
|
cls = HCI_Command.command_classes.get(op_code)
|
||||||
@@ -4518,18 +4529,6 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
|
|
||||||
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
|
|
||||||
'''
|
|
||||||
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
|
|
||||||
'''
|
|
||||||
|
|
||||||
class Enable(enum.IntFlag):
|
|
||||||
REPORTING_ENABLED = 1 << 0
|
|
||||||
DUPLICATE_FILTERING_ENABLED = 1 << 1
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@HCI_Command.command(
|
@HCI_Command.command(
|
||||||
[
|
[
|
||||||
@@ -4565,6 +4564,32 @@ class HCI_LE_Set_Privacy_Mode_Command(HCI_Command):
|
|||||||
return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode)
|
return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@HCI_Command.command([('sync_handle', 2), ('enable', 1)])
|
||||||
|
class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command):
|
||||||
|
'''
|
||||||
|
See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command
|
||||||
|
'''
|
||||||
|
|
||||||
|
class Enable(enum.IntFlag):
|
||||||
|
REPORTING_ENABLED = 1 << 0
|
||||||
|
DUPLICATE_FILTERING_ENABLED = 1 << 1
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@HCI_Command.command(
|
||||||
|
fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)],
|
||||||
|
return_parameters_fields=[
|
||||||
|
('status', STATUS_SPEC),
|
||||||
|
('connection_handle', 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command):
|
||||||
|
'''
|
||||||
|
See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@HCI_Command.command(
|
@HCI_Command.command(
|
||||||
fields=[
|
fields=[
|
||||||
@@ -4807,7 +4832,7 @@ class HCI_Event(HCI_Packet):
|
|||||||
length = packet[2]
|
length = packet[2]
|
||||||
parameters = packet[3:]
|
parameters = packet[3:]
|
||||||
if len(parameters) != length:
|
if len(parameters) != length:
|
||||||
raise ValueError('invalid packet length')
|
raise InvalidPacketError('invalid packet length')
|
||||||
|
|
||||||
cls: Any
|
cls: Any
|
||||||
if event_code == HCI_LE_META_EVENT:
|
if event_code == HCI_LE_META_EVENT:
|
||||||
@@ -5174,8 +5199,8 @@ class HCI_LE_Data_Length_Change_Event(HCI_LE_Meta_Event):
|
|||||||
),
|
),
|
||||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||||
('peer_address', Address.parse_address_preceded_by_type),
|
('peer_address', Address.parse_address_preceded_by_type),
|
||||||
('local_resolvable_private_address', Address.parse_address),
|
('local_resolvable_private_address', Address.parse_random_address),
|
||||||
('peer_resolvable_private_address', Address.parse_address),
|
('peer_resolvable_private_address', Address.parse_random_address),
|
||||||
('connection_interval', 2),
|
('connection_interval', 2),
|
||||||
('peripheral_latency', 2),
|
('peripheral_latency', 2),
|
||||||
('supervision_timeout', 2),
|
('supervision_timeout', 2),
|
||||||
@@ -6342,7 +6367,7 @@ class HCI_AclDataPacket(HCI_Packet):
|
|||||||
bc_flag = (h >> 14) & 3
|
bc_flag = (h >> 14) & 3
|
||||||
data = packet[5:]
|
data = packet[5:]
|
||||||
if len(data) != data_total_length:
|
if len(data) != data_total_length:
|
||||||
raise ValueError('invalid packet length')
|
raise InvalidPacketError('invalid packet length')
|
||||||
return HCI_AclDataPacket(
|
return HCI_AclDataPacket(
|
||||||
connection_handle, pb_flag, bc_flag, data_total_length, data
|
connection_handle, pb_flag, bc_flag, data_total_length, data
|
||||||
)
|
)
|
||||||
@@ -6390,7 +6415,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
|||||||
packet_status = (h >> 12) & 0b11
|
packet_status = (h >> 12) & 0b11
|
||||||
data = packet[4:]
|
data = packet[4:]
|
||||||
if len(data) != data_total_length:
|
if len(data) != data_total_length:
|
||||||
raise ValueError(
|
raise InvalidPacketError(
|
||||||
f'invalid packet length {len(data)} != {data_total_length}'
|
f'invalid packet length {len(data)} != {data_total_length}'
|
||||||
)
|
)
|
||||||
return HCI_SynchronousDataPacket(
|
return HCI_SynchronousDataPacket(
|
||||||
|
|||||||
@@ -23,13 +23,12 @@ import struct
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
from typing import Optional, Callable, TYPE_CHECKING
|
from typing import Optional, Callable
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from bumble import l2cap, device
|
from bumble import l2cap, device
|
||||||
from bumble.colors import color
|
|
||||||
from bumble.core import InvalidStateError, ProtocolError
|
from bumble.core import InvalidStateError, ProtocolError
|
||||||
from .hci import Address
|
from bumble.hci import Address
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
|
|||||||
async def connect_control_channel(self) -> None:
|
async def connect_control_channel(self) -> None:
|
||||||
# Create a new L2CAP connection - control channel
|
# Create a new L2CAP connection - control channel
|
||||||
try:
|
try:
|
||||||
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
|
channel = await self.device.l2cap_channel_manager.connect(
|
||||||
self.connection, HID_CONTROL_PSM
|
self.connection, HID_CONTROL_PSM
|
||||||
)
|
)
|
||||||
|
channel.sink = self.on_ctrl_pdu
|
||||||
|
self.l2cap_ctrl_channel = channel
|
||||||
except ProtocolError:
|
except ProtocolError:
|
||||||
logging.exception(f'L2CAP connection failed.')
|
logging.exception(f'L2CAP connection failed.')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
assert self.l2cap_ctrl_channel is not None
|
|
||||||
# Become a sink for the L2CAP channel
|
|
||||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
|
||||||
|
|
||||||
async def connect_interrupt_channel(self) -> None:
|
async def connect_interrupt_channel(self) -> None:
|
||||||
# Create a new L2CAP connection - interrupt channel
|
# Create a new L2CAP connection - interrupt channel
|
||||||
try:
|
try:
|
||||||
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
|
channel = await self.device.l2cap_channel_manager.connect(
|
||||||
self.connection, HID_INTERRUPT_PSM
|
self.connection, HID_INTERRUPT_PSM
|
||||||
)
|
)
|
||||||
|
channel.sink = self.on_intr_pdu
|
||||||
|
self.l2cap_intr_channel = channel
|
||||||
except ProtocolError:
|
except ProtocolError:
|
||||||
logging.exception(f'L2CAP connection failed.')
|
logging.exception(f'L2CAP connection failed.')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
assert self.l2cap_intr_channel is not None
|
|
||||||
# Become a sink for the L2CAP channel
|
|
||||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
|
||||||
|
|
||||||
async def disconnect_interrupt_channel(self) -> None:
|
async def disconnect_interrupt_channel(self) -> None:
|
||||||
if self.l2cap_intr_channel is None:
|
if self.l2cap_intr_channel is None:
|
||||||
raise InvalidStateError('invalid state')
|
raise InvalidStateError('invalid state')
|
||||||
@@ -334,17 +329,18 @@ class Device(HID):
|
|||||||
ERR_INVALID_PARAMETER = 0x04
|
ERR_INVALID_PARAMETER = 0x04
|
||||||
SUCCESS = 0xFF
|
SUCCESS = 0xFF
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class GetSetStatus:
|
class GetSetStatus:
|
||||||
def __init__(self) -> None:
|
data: bytes = b''
|
||||||
self.data = bytearray()
|
status: int = 0
|
||||||
self.status = 0
|
|
||||||
|
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
|
||||||
|
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
|
||||||
|
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
|
||||||
|
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
|
||||||
|
|
||||||
def __init__(self, device: device.Device) -> None:
|
def __init__(self, device: device.Device) -> None:
|
||||||
super().__init__(device, HID.Role.DEVICE)
|
super().__init__(device, HID.Role.DEVICE)
|
||||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
|
||||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
|
||||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
|
||||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||||
@@ -410,7 +406,6 @@ class Device(HID):
|
|||||||
buffer_size = 0
|
buffer_size = 0
|
||||||
|
|
||||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.FAILURE:
|
if ret.status == self.GetSetReturn.FAILURE:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||||
@@ -428,7 +423,9 @@ class Device(HID):
|
|||||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
def register_get_report_cb(
|
||||||
|
self, cb: Callable[[int, int, int], Device.GetSetStatus]
|
||||||
|
) -> None:
|
||||||
self.get_report_cb = cb
|
self.get_report_cb = cb
|
||||||
logger.debug("GetReport callback registered successfully")
|
logger.debug("GetReport callback registered successfully")
|
||||||
|
|
||||||
@@ -442,7 +439,6 @@ class Device(HID):
|
|||||||
report_data = pdu[2:]
|
report_data = pdu[2:]
|
||||||
report_size = len(report_data) + 1
|
report_size = len(report_data) + 1
|
||||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||||
@@ -453,7 +449,7 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_report_cb(
|
def register_set_report_cb(
|
||||||
self, cb: Callable[[int, int, int, bytes], None]
|
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.set_report_cb = cb
|
self.set_report_cb = cb
|
||||||
logger.debug("SetReport callback registered successfully")
|
logger.debug("SetReport callback registered successfully")
|
||||||
@@ -464,13 +460,12 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
return
|
return
|
||||||
ret = self.get_protocol_cb()
|
ret = self.get_protocol_cb()
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||||
else:
|
else:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
|
||||||
self.get_protocol_cb = cb
|
self.get_protocol_cb = cb
|
||||||
logger.debug("GetProtocol callback registered successfully")
|
logger.debug("GetProtocol callback registered successfully")
|
||||||
|
|
||||||
@@ -480,13 +475,14 @@ class Device(HID):
|
|||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
return
|
return
|
||||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||||
assert ret is not None
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
else:
|
else:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
def register_set_protocol_cb(
|
||||||
|
self, cb: Callable[[int], Device.GetSetStatus]
|
||||||
|
) -> None:
|
||||||
self.set_protocol_cb = cb
|
self.set_protocol_cb = cb
|
||||||
logger.debug("SetProtocol callback registered successfully")
|
logger.debug("SetProtocol callback registered successfully")
|
||||||
|
|
||||||
|
|||||||
@@ -772,6 +772,8 @@ class Host(AbortableEventEmitter):
|
|||||||
event.connection_handle,
|
event.connection_handle,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
event.peer_address,
|
event.peer_address,
|
||||||
|
getattr(event, 'local_resolvable_private_address', None),
|
||||||
|
getattr(event, 'peer_resolvable_private_address', None),
|
||||||
event.role,
|
event.role,
|
||||||
connection_parameters,
|
connection_parameters,
|
||||||
)
|
)
|
||||||
@@ -817,6 +819,8 @@ class Host(AbortableEventEmitter):
|
|||||||
event.bd_addr,
|
event.bd_addr,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
|
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
|
||||||
|
|||||||
@@ -41,7 +41,14 @@ from typing import (
|
|||||||
|
|
||||||
from .utils import deprecated
|
from .utils import deprecated
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
|
from .core import (
|
||||||
|
BT_CENTRAL_ROLE,
|
||||||
|
InvalidStateError,
|
||||||
|
InvalidArgumentError,
|
||||||
|
InvalidPacketError,
|
||||||
|
OutOfResourcesError,
|
||||||
|
ProtocolError,
|
||||||
|
)
|
||||||
from .hci import (
|
from .hci import (
|
||||||
HCI_LE_Connection_Update_Command,
|
HCI_LE_Connection_Update_Command,
|
||||||
HCI_Object,
|
HCI_Object,
|
||||||
@@ -189,17 +196,17 @@ class LeCreditBasedChannelSpec:
|
|||||||
self.max_credits < 1
|
self.max_credits < 1
|
||||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||||
):
|
):
|
||||||
raise ValueError('max credits out of range')
|
raise InvalidArgumentError('max credits out of range')
|
||||||
if (
|
if (
|
||||||
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
||||||
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
||||||
):
|
):
|
||||||
raise ValueError('MTU out of range')
|
raise InvalidArgumentError('MTU out of range')
|
||||||
if (
|
if (
|
||||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||||
):
|
):
|
||||||
raise ValueError('MPS out of range')
|
raise InvalidArgumentError('MPS out of range')
|
||||||
|
|
||||||
|
|
||||||
class L2CAP_PDU:
|
class L2CAP_PDU:
|
||||||
@@ -211,7 +218,7 @@ class L2CAP_PDU:
|
|||||||
def from_bytes(data: bytes) -> L2CAP_PDU:
|
def from_bytes(data: bytes) -> L2CAP_PDU:
|
||||||
# Check parameters
|
# Check parameters
|
||||||
if len(data) < 4:
|
if len(data) < 4:
|
||||||
raise ValueError('not enough data for L2CAP header')
|
raise InvalidPacketError('not enough data for L2CAP header')
|
||||||
|
|
||||||
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
|
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
|
||||||
l2cap_pdu_payload = data[4:]
|
l2cap_pdu_payload = data[4:]
|
||||||
@@ -816,7 +823,7 @@ class ClassicChannel(EventEmitter):
|
|||||||
|
|
||||||
# Check that we can start a new connection
|
# Check that we can start a new connection
|
||||||
if self.connection_result:
|
if self.connection_result:
|
||||||
raise RuntimeError('connection already pending')
|
raise InvalidStateError('connection already pending')
|
||||||
|
|
||||||
self._change_state(self.State.WAIT_CONNECT_RSP)
|
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
@@ -1129,7 +1136,7 @@ class LeCreditBasedChannel(EventEmitter):
|
|||||||
# Check that we can start a new connection
|
# Check that we can start a new connection
|
||||||
identifier = self.manager.next_identifier(self.connection)
|
identifier = self.manager.next_identifier(self.connection)
|
||||||
if identifier in self.manager.le_coc_requests:
|
if identifier in self.manager.le_coc_requests:
|
||||||
raise RuntimeError('too many concurrent connection requests')
|
raise InvalidStateError('too many concurrent connection requests')
|
||||||
|
|
||||||
self._change_state(self.State.CONNECTING)
|
self._change_state(self.State.CONNECTING)
|
||||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||||
@@ -1516,7 +1523,7 @@ class ChannelManager:
|
|||||||
if cid not in channels:
|
if cid not in channels:
|
||||||
return cid
|
return cid
|
||||||
|
|
||||||
raise RuntimeError('no free CID available')
|
raise OutOfResourcesError('no free CID available')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_free_le_cid(channels: Iterable[int]) -> int:
|
def find_free_le_cid(channels: Iterable[int]) -> int:
|
||||||
@@ -1529,7 +1536,7 @@ class ChannelManager:
|
|||||||
if cid not in channels:
|
if cid not in channels:
|
||||||
return cid
|
return cid
|
||||||
|
|
||||||
raise RuntimeError('no free CID')
|
raise OutOfResourcesError('no free CID')
|
||||||
|
|
||||||
def next_identifier(self, connection: Connection) -> int:
|
def next_identifier(self, connection: Connection) -> int:
|
||||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||||
@@ -1576,15 +1583,15 @@ class ChannelManager:
|
|||||||
else:
|
else:
|
||||||
# Check that the PSM isn't already in use
|
# Check that the PSM isn't already in use
|
||||||
if spec.psm in self.servers:
|
if spec.psm in self.servers:
|
||||||
raise ValueError('PSM already in use')
|
raise InvalidArgumentError('PSM already in use')
|
||||||
|
|
||||||
# Check that the PSM is valid
|
# Check that the PSM is valid
|
||||||
if spec.psm % 2 == 0:
|
if spec.psm % 2 == 0:
|
||||||
raise ValueError('invalid PSM (not odd)')
|
raise InvalidArgumentError('invalid PSM (not odd)')
|
||||||
check = spec.psm >> 8
|
check = spec.psm >> 8
|
||||||
while check:
|
while check:
|
||||||
if check % 2 != 0:
|
if check % 2 != 0:
|
||||||
raise ValueError('invalid PSM')
|
raise InvalidArgumentError('invalid PSM')
|
||||||
check >>= 8
|
check >>= 8
|
||||||
|
|
||||||
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
|
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
|
||||||
@@ -1626,7 +1633,7 @@ class ChannelManager:
|
|||||||
else:
|
else:
|
||||||
# Check that the PSM isn't already in use
|
# Check that the PSM isn't already in use
|
||||||
if spec.psm in self.le_coc_servers:
|
if spec.psm in self.le_coc_servers:
|
||||||
raise ValueError('PSM already in use')
|
raise InvalidArgumentError('PSM already in use')
|
||||||
|
|
||||||
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
|
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
|
||||||
self,
|
self,
|
||||||
@@ -2154,10 +2161,10 @@ class ChannelManager:
|
|||||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||||
source_cid = self.find_free_le_cid(connection_channels)
|
source_cid = self.find_free_le_cid(connection_channels)
|
||||||
if source_cid is None: # Should never happen!
|
if source_cid is None: # Should never happen!
|
||||||
raise RuntimeError('all CIDs already in use')
|
raise OutOfResourcesError('all CIDs already in use')
|
||||||
|
|
||||||
if spec.psm is None:
|
if spec.psm is None:
|
||||||
raise ValueError('PSM cannot be None')
|
raise InvalidArgumentError('PSM cannot be None')
|
||||||
|
|
||||||
# Create the channel
|
# Create the channel
|
||||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
|
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
|
||||||
@@ -2206,10 +2213,10 @@ class ChannelManager:
|
|||||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||||
source_cid = self.find_free_br_edr_cid(connection_channels)
|
source_cid = self.find_free_br_edr_cid(connection_channels)
|
||||||
if source_cid is None: # Should never happen!
|
if source_cid is None: # Should never happen!
|
||||||
raise RuntimeError('all CIDs already in use')
|
raise OutOfResourcesError('all CIDs already in use')
|
||||||
|
|
||||||
if spec.psm is None:
|
if spec.psm is None:
|
||||||
raise ValueError('PSM cannot be None')
|
raise InvalidArgumentError('PSM cannot be None')
|
||||||
|
|
||||||
# Create the channel
|
# Create the channel
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
|
from bumble.core import (
|
||||||
|
BT_PERIPHERAL_ROLE,
|
||||||
|
BT_BR_EDR_TRANSPORT,
|
||||||
|
BT_LE_TRANSPORT,
|
||||||
|
InvalidStateError,
|
||||||
|
)
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.hci import (
|
from bumble.hci import (
|
||||||
Address,
|
Address,
|
||||||
@@ -405,12 +410,12 @@ class RemoteLink:
|
|||||||
|
|
||||||
def add_controller(self, controller):
|
def add_controller(self, controller):
|
||||||
if self.controller:
|
if self.controller:
|
||||||
raise ValueError('controller already set')
|
raise InvalidStateError('controller already set')
|
||||||
self.controller = controller
|
self.controller = controller
|
||||||
|
|
||||||
def remove_controller(self, controller):
|
def remove_controller(self, controller):
|
||||||
if self.controller != controller:
|
if self.controller != controller:
|
||||||
raise ValueError('controller mismatch')
|
raise InvalidStateError('controller mismatch')
|
||||||
self.controller = None
|
self.controller = None
|
||||||
|
|
||||||
def get_pending_connection(self):
|
def get_pending_connection(self):
|
||||||
|
|||||||
739
bumble/profiles/ascs.py
Normal file
739
bumble/profiles/ascs.py
Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
# Copyright 2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for
|
||||||
|
|
||||||
|
"""LE Audio - Audio Stream Control Service"""
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
|
from bumble import colors
|
||||||
|
from bumble.profiles.bap import CodecSpecificConfiguration
|
||||||
|
from bumble.profiles import le_audio
|
||||||
|
from bumble import device
|
||||||
|
from bumble import gatt
|
||||||
|
from bumble import gatt_client
|
||||||
|
from bumble import hci
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# ASE Operations
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ASE_Operation:
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service - 5 ASE Control operations.
|
||||||
|
'''
|
||||||
|
|
||||||
|
classes: Dict[int, Type[ASE_Operation]] = {}
|
||||||
|
op_code: int
|
||||||
|
name: str
|
||||||
|
fields: Optional[Sequence[Any]] = None
|
||||||
|
ase_id: List[int]
|
||||||
|
|
||||||
|
class Opcode(enum.IntEnum):
|
||||||
|
# fmt: off
|
||||||
|
CONFIG_CODEC = 0x01
|
||||||
|
CONFIG_QOS = 0x02
|
||||||
|
ENABLE = 0x03
|
||||||
|
RECEIVER_START_READY = 0x04
|
||||||
|
DISABLE = 0x05
|
||||||
|
RECEIVER_STOP_READY = 0x06
|
||||||
|
UPDATE_METADATA = 0x07
|
||||||
|
RELEASE = 0x08
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bytes(pdu: bytes) -> ASE_Operation:
|
||||||
|
op_code = pdu[0]
|
||||||
|
|
||||||
|
cls = ASE_Operation.classes.get(op_code)
|
||||||
|
if cls is None:
|
||||||
|
instance = ASE_Operation(pdu)
|
||||||
|
instance.name = ASE_Operation.Opcode(op_code).name
|
||||||
|
instance.op_code = op_code
|
||||||
|
return instance
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
ASE_Operation.__init__(self, pdu)
|
||||||
|
if self.fields is not None:
|
||||||
|
self.init_from_bytes(pdu, 1)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def subclass(fields):
|
||||||
|
def inner(cls: Type[ASE_Operation]):
|
||||||
|
try:
|
||||||
|
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
|
||||||
|
cls.name = operation.name
|
||||||
|
cls.op_code = operation
|
||||||
|
except:
|
||||||
|
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
|
||||||
|
cls.fields = fields
|
||||||
|
|
||||||
|
# Register a factory for this class
|
||||||
|
ASE_Operation.classes[cls.op_code] = cls
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
|
||||||
|
if self.fields is not None and kwargs:
|
||||||
|
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||||
|
if pdu is None:
|
||||||
|
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
||||||
|
kwargs, self.fields
|
||||||
|
)
|
||||||
|
self.pdu = pdu
|
||||||
|
|
||||||
|
def init_from_bytes(self, pdu: bytes, offset: int):
|
||||||
|
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return self.pdu
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
result = f'{colors.color(self.name, "yellow")} '
|
||||||
|
if fields := getattr(self, 'fields', None):
|
||||||
|
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
|
||||||
|
else:
|
||||||
|
if len(self.pdu) > 1:
|
||||||
|
result += f': {self.pdu.hex()}'
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
('ase_id', 1),
|
||||||
|
('target_latency', 1),
|
||||||
|
('target_phy', 1),
|
||||||
|
('codec_id', hci.CodingFormat.parse_from_bytes),
|
||||||
|
('codec_specific_configuration', 'v'),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
class ASE_Config_Codec(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.1 - Config Codec Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
target_latency: List[int]
|
||||||
|
target_phy: List[int]
|
||||||
|
codec_id: List[hci.CodingFormat]
|
||||||
|
codec_specific_configuration: List[bytes]
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
('ase_id', 1),
|
||||||
|
('cig_id', 1),
|
||||||
|
('cis_id', 1),
|
||||||
|
('sdu_interval', 3),
|
||||||
|
('framing', 1),
|
||||||
|
('phy', 1),
|
||||||
|
('max_sdu', 2),
|
||||||
|
('retransmission_number', 1),
|
||||||
|
('max_transport_latency', 2),
|
||||||
|
('presentation_delay', 3),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
class ASE_Config_QOS(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.2 - Config Qos Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
cig_id: List[int]
|
||||||
|
cis_id: List[int]
|
||||||
|
sdu_interval: List[int]
|
||||||
|
framing: List[int]
|
||||||
|
phy: List[int]
|
||||||
|
max_sdu: List[int]
|
||||||
|
retransmission_number: List[int]
|
||||||
|
max_transport_latency: List[int]
|
||||||
|
presentation_delay: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||||
|
class ASE_Enable(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.3 - Enable Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
metadata: bytes
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||||
|
class ASE_Receiver_Start_Ready(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||||
|
class ASE_Disable(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.5 - Disable Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||||
|
class ASE_Receiver_Stop_Ready(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||||
|
class ASE_Update_Metadata(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.7 - Update Metadata Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
metadata: List[bytes]
|
||||||
|
|
||||||
|
|
||||||
|
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||||
|
class ASE_Release(ASE_Operation):
|
||||||
|
'''
|
||||||
|
See Audio Stream Control Service 5.8 - Release Operation
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class AseResponseCode(enum.IntEnum):
|
||||||
|
# fmt: off
|
||||||
|
SUCCESS = 0x00
|
||||||
|
UNSUPPORTED_OPCODE = 0x01
|
||||||
|
INVALID_LENGTH = 0x02
|
||||||
|
INVALID_ASE_ID = 0x03
|
||||||
|
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
|
||||||
|
INVALID_ASE_DIRECTION = 0x05
|
||||||
|
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
|
||||||
|
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
|
||||||
|
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
|
||||||
|
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
|
||||||
|
UNSUPPORTED_METADATA = 0x0A
|
||||||
|
REJECTED_METADATA = 0x0B
|
||||||
|
INVALID_METADATA = 0x0C
|
||||||
|
INSUFFICIENT_RESOURCES = 0x0D
|
||||||
|
UNSPECIFIED_ERROR = 0x0E
|
||||||
|
|
||||||
|
|
||||||
|
class AseReasonCode(enum.IntEnum):
|
||||||
|
# fmt: off
|
||||||
|
NONE = 0x00
|
||||||
|
CODEC_ID = 0x01
|
||||||
|
CODEC_SPECIFIC_CONFIGURATION = 0x02
|
||||||
|
SDU_INTERVAL = 0x03
|
||||||
|
FRAMING = 0x04
|
||||||
|
PHY = 0x05
|
||||||
|
MAXIMUM_SDU_SIZE = 0x06
|
||||||
|
RETRANSMISSION_NUMBER = 0x07
|
||||||
|
MAX_TRANSPORT_LATENCY = 0x08
|
||||||
|
PRESENTATION_DELAY = 0x09
|
||||||
|
INVALID_ASE_CIS_MAPPING = 0x0A
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class AudioRole(enum.IntEnum):
|
||||||
|
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||||
|
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class AseStateMachine(gatt.Characteristic):
|
||||||
|
class State(enum.IntEnum):
|
||||||
|
# fmt: off
|
||||||
|
IDLE = 0x00
|
||||||
|
CODEC_CONFIGURED = 0x01
|
||||||
|
QOS_CONFIGURED = 0x02
|
||||||
|
ENABLING = 0x03
|
||||||
|
STREAMING = 0x04
|
||||||
|
DISABLING = 0x05
|
||||||
|
RELEASING = 0x06
|
||||||
|
|
||||||
|
cis_link: Optional[device.CisLink] = None
|
||||||
|
|
||||||
|
# Additional parameters in CODEC_CONFIGURED State
|
||||||
|
preferred_framing = 0 # Unframed PDU supported
|
||||||
|
preferred_phy = 0
|
||||||
|
preferred_retransmission_number = 13
|
||||||
|
preferred_max_transport_latency = 100
|
||||||
|
supported_presentation_delay_min = 0
|
||||||
|
supported_presentation_delay_max = 0
|
||||||
|
preferred_presentation_delay_min = 0
|
||||||
|
preferred_presentation_delay_max = 0
|
||||||
|
codec_id = hci.CodingFormat(hci.CodecID.LC3)
|
||||||
|
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
|
||||||
|
|
||||||
|
# Additional parameters in QOS_CONFIGURED State
|
||||||
|
cig_id = 0
|
||||||
|
cis_id = 0
|
||||||
|
sdu_interval = 0
|
||||||
|
framing = 0
|
||||||
|
phy = 0
|
||||||
|
max_sdu = 0
|
||||||
|
retransmission_number = 0
|
||||||
|
max_transport_latency = 0
|
||||||
|
presentation_delay = 0
|
||||||
|
|
||||||
|
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
||||||
|
metadata = le_audio.Metadata()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role: AudioRole,
|
||||||
|
ase_id: int,
|
||||||
|
service: AudioStreamControlService,
|
||||||
|
) -> None:
|
||||||
|
self.service = service
|
||||||
|
self.ase_id = ase_id
|
||||||
|
self._state = AseStateMachine.State.IDLE
|
||||||
|
self.role = role
|
||||||
|
|
||||||
|
uuid = (
|
||||||
|
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||||
|
if role == AudioRole.SINK
|
||||||
|
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
uuid=uuid,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=gatt.CharacteristicValue(read=self.on_read),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.service.device.on('cis_request', self.on_cis_request)
|
||||||
|
self.service.device.on('cis_establishment', self.on_cis_establishment)
|
||||||
|
|
||||||
|
def on_cis_request(
|
||||||
|
self,
|
||||||
|
acl_connection: device.Connection,
|
||||||
|
cis_handle: int,
|
||||||
|
cig_id: int,
|
||||||
|
cis_id: int,
|
||||||
|
) -> None:
|
||||||
|
if (
|
||||||
|
cig_id == self.cig_id
|
||||||
|
and cis_id == self.cis_id
|
||||||
|
and self.state == self.State.ENABLING
|
||||||
|
):
|
||||||
|
acl_connection.abort_on(
|
||||||
|
'flush', self.service.device.accept_cis_request(cis_handle)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
||||||
|
if (
|
||||||
|
cis_link.cig_id == self.cig_id
|
||||||
|
and cis_link.cis_id == self.cis_id
|
||||||
|
and self.state == self.State.ENABLING
|
||||||
|
):
|
||||||
|
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||||
|
|
||||||
|
async def post_cis_established():
|
||||||
|
await self.service.device.send_command(
|
||||||
|
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||||
|
connection_handle=cis_link.handle,
|
||||||
|
data_path_direction=self.role,
|
||||||
|
data_path_id=0x00, # Fixed HCI
|
||||||
|
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||||
|
controller_delay=0,
|
||||||
|
codec_configuration=b'',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.role == AudioRole.SINK:
|
||||||
|
self.state = self.State.STREAMING
|
||||||
|
await self.service.device.notify_subscribers(self, self.value)
|
||||||
|
|
||||||
|
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
||||||
|
self.cis_link = cis_link
|
||||||
|
|
||||||
|
def on_cis_disconnection(self, _reason) -> None:
|
||||||
|
self.cis_link = None
|
||||||
|
|
||||||
|
def on_config_codec(
|
||||||
|
self,
|
||||||
|
target_latency: int,
|
||||||
|
target_phy: int,
|
||||||
|
codec_id: hci.CodingFormat,
|
||||||
|
codec_specific_configuration: bytes,
|
||||||
|
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state not in (
|
||||||
|
self.State.IDLE,
|
||||||
|
self.State.CODEC_CONFIGURED,
|
||||||
|
self.State.QOS_CONFIGURED,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.max_transport_latency = target_latency
|
||||||
|
self.phy = target_phy
|
||||||
|
self.codec_id = codec_id
|
||||||
|
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||||
|
self.codec_specific_configuration = codec_specific_configuration
|
||||||
|
else:
|
||||||
|
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
|
||||||
|
codec_specific_configuration
|
||||||
|
)
|
||||||
|
|
||||||
|
self.state = self.State.CODEC_CONFIGURED
|
||||||
|
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_config_qos(
|
||||||
|
self,
|
||||||
|
cig_id: int,
|
||||||
|
cis_id: int,
|
||||||
|
sdu_interval: int,
|
||||||
|
framing: int,
|
||||||
|
phy: int,
|
||||||
|
max_sdu: int,
|
||||||
|
retransmission_number: int,
|
||||||
|
max_transport_latency: int,
|
||||||
|
presentation_delay: int,
|
||||||
|
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state not in (
|
||||||
|
AseStateMachine.State.CODEC_CONFIGURED,
|
||||||
|
AseStateMachine.State.QOS_CONFIGURED,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cig_id = cig_id
|
||||||
|
self.cis_id = cis_id
|
||||||
|
self.sdu_interval = sdu_interval
|
||||||
|
self.framing = framing
|
||||||
|
self.phy = phy
|
||||||
|
self.max_sdu = max_sdu
|
||||||
|
self.retransmission_number = retransmission_number
|
||||||
|
self.max_transport_latency = max_transport_latency
|
||||||
|
self.presentation_delay = presentation_delay
|
||||||
|
|
||||||
|
self.state = self.State.QOS_CONFIGURED
|
||||||
|
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state != AseStateMachine.State.QOS_CONFIGURED:
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||||
|
self.state = self.State.ENABLING
|
||||||
|
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state != AseStateMachine.State.ENABLING:
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
self.state = self.State.STREAMING
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state not in (
|
||||||
|
AseStateMachine.State.ENABLING,
|
||||||
|
AseStateMachine.State.STREAMING,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
if self.role == AudioRole.SINK:
|
||||||
|
self.state = self.State.QOS_CONFIGURED
|
||||||
|
else:
|
||||||
|
self.state = self.State.DISABLING
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if (
|
||||||
|
self.role != AudioRole.SOURCE
|
||||||
|
or self.state != AseStateMachine.State.DISABLING
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
self.state = self.State.QOS_CONFIGURED
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_update_metadata(
|
||||||
|
self, metadata: bytes
|
||||||
|
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state not in (
|
||||||
|
AseStateMachine.State.ENABLING,
|
||||||
|
AseStateMachine.State.STREAMING,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||||
|
if self.state == AseStateMachine.State.IDLE:
|
||||||
|
return (
|
||||||
|
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||||
|
AseReasonCode.NONE,
|
||||||
|
)
|
||||||
|
self.state = self.State.RELEASING
|
||||||
|
|
||||||
|
async def remove_cis_async():
|
||||||
|
await self.service.device.send_command(
|
||||||
|
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||||
|
connection_handle=self.cis_link.handle,
|
||||||
|
data_path_direction=self.role,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.state = self.State.IDLE
|
||||||
|
await self.service.device.notify_subscribers(self, self.value)
|
||||||
|
|
||||||
|
self.service.device.abort_on('flush', remove_cis_async())
|
||||||
|
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> State:
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, new_state: State) -> None:
|
||||||
|
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
||||||
|
self._state = new_state
|
||||||
|
self.emit('state_change')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
|
||||||
|
|
||||||
|
if self.state == self.State.CODEC_CONFIGURED:
|
||||||
|
codec_specific_configuration_bytes = bytes(
|
||||||
|
self.codec_specific_configuration
|
||||||
|
)
|
||||||
|
additional_parameters = (
|
||||||
|
struct.pack(
|
||||||
|
'<BBBH',
|
||||||
|
self.preferred_framing,
|
||||||
|
self.preferred_phy,
|
||||||
|
self.preferred_retransmission_number,
|
||||||
|
self.preferred_max_transport_latency,
|
||||||
|
)
|
||||||
|
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
|
||||||
|
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
|
||||||
|
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
|
||||||
|
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
|
||||||
|
+ bytes(self.codec_id)
|
||||||
|
+ bytes([len(codec_specific_configuration_bytes)])
|
||||||
|
+ codec_specific_configuration_bytes
|
||||||
|
)
|
||||||
|
elif self.state == self.State.QOS_CONFIGURED:
|
||||||
|
additional_parameters = (
|
||||||
|
bytes([self.cig_id, self.cis_id])
|
||||||
|
+ self.sdu_interval.to_bytes(3, 'little')
|
||||||
|
+ struct.pack(
|
||||||
|
'<BBHBH',
|
||||||
|
self.framing,
|
||||||
|
self.phy,
|
||||||
|
self.max_sdu,
|
||||||
|
self.retransmission_number,
|
||||||
|
self.max_transport_latency,
|
||||||
|
)
|
||||||
|
+ self.presentation_delay.to_bytes(3, 'little')
|
||||||
|
)
|
||||||
|
elif self.state in (
|
||||||
|
self.State.ENABLING,
|
||||||
|
self.State.STREAMING,
|
||||||
|
self.State.DISABLING,
|
||||||
|
):
|
||||||
|
metadata_bytes = bytes(self.metadata)
|
||||||
|
additional_parameters = (
|
||||||
|
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_parameters = b''
|
||||||
|
|
||||||
|
return bytes([self.ase_id, self.state]) + additional_parameters
|
||||||
|
|
||||||
|
@value.setter
|
||||||
|
def value(self, _new_value):
|
||||||
|
# Readonly. Do nothing in the setter.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_read(self, _: Optional[device.Connection]) -> bytes:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return (
|
||||||
|
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
|
||||||
|
f'state={self._state.name})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class AudioStreamControlService(gatt.TemplateService):
|
||||||
|
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
|
||||||
|
|
||||||
|
ase_state_machines: Dict[int, AseStateMachine]
|
||||||
|
ase_control_point: gatt.Characteristic
|
||||||
|
_active_client: Optional[device.Connection] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: device.Device,
|
||||||
|
source_ase_id: Sequence[int] = (),
|
||||||
|
sink_ase_id: Sequence[int] = (),
|
||||||
|
) -> None:
|
||||||
|
self.device = device
|
||||||
|
self.ase_state_machines = {
|
||||||
|
**{
|
||||||
|
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
|
||||||
|
for id in sink_ase_id
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
|
||||||
|
for id in source_ase_id
|
||||||
|
},
|
||||||
|
} # ASE state machines, by ASE ID
|
||||||
|
|
||||||
|
self.ase_control_point = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.WRITE
|
||||||
|
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.WRITEABLE,
|
||||||
|
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
|
||||||
|
|
||||||
|
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
|
||||||
|
if ase := self.ase_state_machines.get(ase_id):
|
||||||
|
handler = getattr(ase, 'on_' + opcode.name.lower())
|
||||||
|
return (ase_id, *handler(*args))
|
||||||
|
else:
|
||||||
|
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
|
||||||
|
|
||||||
|
def _on_client_disconnected(self, _reason: int) -> None:
|
||||||
|
for ase in self.ase_state_machines.values():
|
||||||
|
ase.state = AseStateMachine.State.IDLE
|
||||||
|
self._active_client = None
|
||||||
|
|
||||||
|
def on_write_ase_control_point(self, connection, data):
|
||||||
|
if not self._active_client and connection:
|
||||||
|
self._active_client = connection
|
||||||
|
connection.once('disconnection', self._on_client_disconnected)
|
||||||
|
|
||||||
|
operation = ASE_Operation.from_bytes(data)
|
||||||
|
responses = []
|
||||||
|
logger.debug(f'*** ASCS Write {operation} ***')
|
||||||
|
|
||||||
|
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
|
||||||
|
for ase_id, *args in zip(
|
||||||
|
operation.ase_id,
|
||||||
|
operation.target_latency,
|
||||||
|
operation.target_phy,
|
||||||
|
operation.codec_id,
|
||||||
|
operation.codec_specific_configuration,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
|
||||||
|
for ase_id, *args in zip(
|
||||||
|
operation.ase_id,
|
||||||
|
operation.cig_id,
|
||||||
|
operation.cis_id,
|
||||||
|
operation.sdu_interval,
|
||||||
|
operation.framing,
|
||||||
|
operation.phy,
|
||||||
|
operation.max_sdu,
|
||||||
|
operation.retransmission_number,
|
||||||
|
operation.max_transport_latency,
|
||||||
|
operation.presentation_delay,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
elif operation.op_code in (
|
||||||
|
ASE_Operation.Opcode.ENABLE,
|
||||||
|
ASE_Operation.Opcode.UPDATE_METADATA,
|
||||||
|
):
|
||||||
|
for ase_id, *args in zip(
|
||||||
|
operation.ase_id,
|
||||||
|
operation.metadata,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
elif operation.op_code in (
|
||||||
|
ASE_Operation.Opcode.RECEIVER_START_READY,
|
||||||
|
ASE_Operation.Opcode.DISABLE,
|
||||||
|
ASE_Operation.Opcode.RECEIVER_STOP_READY,
|
||||||
|
ASE_Operation.Opcode.RELEASE,
|
||||||
|
):
|
||||||
|
for ase_id in operation.ase_id:
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||||
|
|
||||||
|
control_point_notification = bytes(
|
||||||
|
[operation.op_code, len(responses)]
|
||||||
|
) + b''.join(map(bytes, responses))
|
||||||
|
self.device.abort_on(
|
||||||
|
'flush',
|
||||||
|
self.device.notify_subscribers(
|
||||||
|
self.ase_control_point, control_point_notification
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
for ase_id, *_ in responses:
|
||||||
|
if ase := self.ase_state_machines.get(ase_id):
|
||||||
|
self.device.abort_on(
|
||||||
|
'flush',
|
||||||
|
self.device.notify_subscribers(ase, ase.value),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||||
|
SERVICE_CLASS = AudioStreamControlService
|
||||||
|
|
||||||
|
sink_ase: List[gatt_client.CharacteristicProxy]
|
||||||
|
source_ase: List[gatt_client.CharacteristicProxy]
|
||||||
|
ase_control_point: gatt_client.CharacteristicProxy
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
|
||||||
|
self.sink_ase = service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
self.source_ase = service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
|
||||||
|
)[0]
|
||||||
@@ -24,15 +24,12 @@ import enum
|
|||||||
import struct
|
import struct
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List, Union, Type, Dict, Any, Tuple
|
from typing import List
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from bumble import core
|
from bumble import core
|
||||||
from bumble import colors
|
|
||||||
from bumble import device
|
|
||||||
from bumble import hci
|
from bumble import hci
|
||||||
from bumble import gatt
|
from bumble import gatt
|
||||||
from bumble import gatt_client
|
|
||||||
from bumble import utils
|
from bumble import utils
|
||||||
from bumble.profiles import le_audio
|
from bumble.profiles import le_audio
|
||||||
|
|
||||||
@@ -251,231 +248,6 @@ class AnnouncementType(utils.OpenIntEnum):
|
|||||||
TARGETED = 0x01
|
TARGETED = 0x01
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# ASE Operations
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ASE_Operation:
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service - 5 ASE Control operations.
|
|
||||||
'''
|
|
||||||
|
|
||||||
classes: Dict[int, Type[ASE_Operation]] = {}
|
|
||||||
op_code: int
|
|
||||||
name: str
|
|
||||||
fields: Optional[Sequence[Any]] = None
|
|
||||||
ase_id: List[int]
|
|
||||||
|
|
||||||
class Opcode(enum.IntEnum):
|
|
||||||
# fmt: off
|
|
||||||
CONFIG_CODEC = 0x01
|
|
||||||
CONFIG_QOS = 0x02
|
|
||||||
ENABLE = 0x03
|
|
||||||
RECEIVER_START_READY = 0x04
|
|
||||||
DISABLE = 0x05
|
|
||||||
RECEIVER_STOP_READY = 0x06
|
|
||||||
UPDATE_METADATA = 0x07
|
|
||||||
RELEASE = 0x08
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_bytes(pdu: bytes) -> ASE_Operation:
|
|
||||||
op_code = pdu[0]
|
|
||||||
|
|
||||||
cls = ASE_Operation.classes.get(op_code)
|
|
||||||
if cls is None:
|
|
||||||
instance = ASE_Operation(pdu)
|
|
||||||
instance.name = ASE_Operation.Opcode(op_code).name
|
|
||||||
instance.op_code = op_code
|
|
||||||
return instance
|
|
||||||
self = cls.__new__(cls)
|
|
||||||
ASE_Operation.__init__(self, pdu)
|
|
||||||
if self.fields is not None:
|
|
||||||
self.init_from_bytes(pdu, 1)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def subclass(fields):
|
|
||||||
def inner(cls: Type[ASE_Operation]):
|
|
||||||
try:
|
|
||||||
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
|
|
||||||
cls.name = operation.name
|
|
||||||
cls.op_code = operation
|
|
||||||
except:
|
|
||||||
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
|
|
||||||
cls.fields = fields
|
|
||||||
|
|
||||||
# Register a factory for this class
|
|
||||||
ASE_Operation.classes[cls.op_code] = cls
|
|
||||||
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return inner
|
|
||||||
|
|
||||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
|
|
||||||
if self.fields is not None and kwargs:
|
|
||||||
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
|
|
||||||
if pdu is None:
|
|
||||||
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
|
||||||
kwargs, self.fields
|
|
||||||
)
|
|
||||||
self.pdu = pdu
|
|
||||||
|
|
||||||
def init_from_bytes(self, pdu: bytes, offset: int):
|
|
||||||
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
|
||||||
return self.pdu
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
result = f'{colors.color(self.name, "yellow")} '
|
|
||||||
if fields := getattr(self, 'fields', None):
|
|
||||||
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
|
|
||||||
else:
|
|
||||||
if len(self.pdu) > 1:
|
|
||||||
result += f': {self.pdu.hex()}'
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
('ase_id', 1),
|
|
||||||
('target_latency', 1),
|
|
||||||
('target_phy', 1),
|
|
||||||
('codec_id', hci.CodingFormat.parse_from_bytes),
|
|
||||||
('codec_specific_configuration', 'v'),
|
|
||||||
],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
class ASE_Config_Codec(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.1 - Config Codec Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
target_latency: List[int]
|
|
||||||
target_phy: List[int]
|
|
||||||
codec_id: List[hci.CodingFormat]
|
|
||||||
codec_specific_configuration: List[bytes]
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
('ase_id', 1),
|
|
||||||
('cig_id', 1),
|
|
||||||
('cis_id', 1),
|
|
||||||
('sdu_interval', 3),
|
|
||||||
('framing', 1),
|
|
||||||
('phy', 1),
|
|
||||||
('max_sdu', 2),
|
|
||||||
('retransmission_number', 1),
|
|
||||||
('max_transport_latency', 2),
|
|
||||||
('presentation_delay', 3),
|
|
||||||
],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
class ASE_Config_QOS(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.2 - Config Qos Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
cig_id: List[int]
|
|
||||||
cis_id: List[int]
|
|
||||||
sdu_interval: List[int]
|
|
||||||
framing: List[int]
|
|
||||||
phy: List[int]
|
|
||||||
max_sdu: List[int]
|
|
||||||
retransmission_number: List[int]
|
|
||||||
max_transport_latency: List[int]
|
|
||||||
presentation_delay: List[int]
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
|
||||||
class ASE_Enable(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.3 - Enable Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
metadata: bytes
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
|
||||||
class ASE_Receiver_Start_Ready(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
|
||||||
class ASE_Disable(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.5 - Disable Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
|
||||||
class ASE_Receiver_Stop_Ready(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
|
||||||
class ASE_Update_Metadata(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.7 - Update Metadata Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
metadata: List[bytes]
|
|
||||||
|
|
||||||
|
|
||||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
|
||||||
class ASE_Release(ASE_Operation):
|
|
||||||
'''
|
|
||||||
See Audio Stream Control Service 5.8 - Release Operation
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
class AseResponseCode(enum.IntEnum):
|
|
||||||
# fmt: off
|
|
||||||
SUCCESS = 0x00
|
|
||||||
UNSUPPORTED_OPCODE = 0x01
|
|
||||||
INVALID_LENGTH = 0x02
|
|
||||||
INVALID_ASE_ID = 0x03
|
|
||||||
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
|
|
||||||
INVALID_ASE_DIRECTION = 0x05
|
|
||||||
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
|
|
||||||
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
|
|
||||||
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
|
|
||||||
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
|
|
||||||
UNSUPPORTED_METADATA = 0x0A
|
|
||||||
REJECTED_METADATA = 0x0B
|
|
||||||
INVALID_METADATA = 0x0C
|
|
||||||
INSUFFICIENT_RESOURCES = 0x0D
|
|
||||||
UNSPECIFIED_ERROR = 0x0E
|
|
||||||
|
|
||||||
|
|
||||||
class AseReasonCode(enum.IntEnum):
|
|
||||||
# fmt: off
|
|
||||||
NONE = 0x00
|
|
||||||
CODEC_ID = 0x01
|
|
||||||
CODEC_SPECIFIC_CONFIGURATION = 0x02
|
|
||||||
SDU_INTERVAL = 0x03
|
|
||||||
FRAMING = 0x04
|
|
||||||
PHY = 0x05
|
|
||||||
MAXIMUM_SDU_SIZE = 0x06
|
|
||||||
RETRANSMISSION_NUMBER = 0x07
|
|
||||||
MAX_TRANSPORT_LATENCY = 0x08
|
|
||||||
PRESENTATION_DELAY = 0x09
|
|
||||||
INVALID_ASE_CIS_MAPPING = 0x0A
|
|
||||||
|
|
||||||
|
|
||||||
class AudioRole(enum.IntEnum):
|
|
||||||
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
|
||||||
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class UnicastServerAdvertisingData:
|
class UnicastServerAdvertisingData:
|
||||||
"""Advertising Data for ASCS."""
|
"""Advertising Data for ASCS."""
|
||||||
@@ -683,51 +455,6 @@ class CodecSpecificConfiguration:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PacRecord:
|
|
||||||
coding_format: hci.CodingFormat
|
|
||||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
|
||||||
# TODO: Parse Metadata
|
|
||||||
metadata: bytes = b''
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_bytes(cls, data: bytes) -> PacRecord:
|
|
||||||
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
|
|
||||||
codec_specific_capabilities_size = data[offset]
|
|
||||||
|
|
||||||
offset += 1
|
|
||||||
codec_specific_capabilities_bytes = data[
|
|
||||||
offset : offset + codec_specific_capabilities_size
|
|
||||||
]
|
|
||||||
offset += codec_specific_capabilities_size
|
|
||||||
metadata_size = data[offset]
|
|
||||||
metadata = data[offset : offset + metadata_size]
|
|
||||||
|
|
||||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
|
||||||
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
|
||||||
codec_specific_capabilities = codec_specific_capabilities_bytes
|
|
||||||
else:
|
|
||||||
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
|
|
||||||
codec_specific_capabilities_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
return PacRecord(
|
|
||||||
coding_format=coding_format,
|
|
||||||
codec_specific_capabilities=codec_specific_capabilities,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
|
||||||
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
|
||||||
return (
|
|
||||||
bytes(self.coding_format)
|
|
||||||
+ bytes([len(capabilities_bytes)])
|
|
||||||
+ capabilities_bytes
|
|
||||||
+ bytes([len(self.metadata)])
|
|
||||||
+ self.metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BroadcastAudioAnnouncement:
|
class BroadcastAudioAnnouncement:
|
||||||
broadcast_id: int
|
broadcast_id: int
|
||||||
@@ -819,603 +546,3 @@ class BasicAudioAnnouncement:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cls(presentation_delay, subgroups)
|
return cls(presentation_delay, subgroups)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Server
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
|
||||||
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
|
|
||||||
|
|
||||||
sink_pac: Optional[gatt.Characteristic]
|
|
||||||
sink_audio_locations: Optional[gatt.Characteristic]
|
|
||||||
source_pac: Optional[gatt.Characteristic]
|
|
||||||
source_audio_locations: Optional[gatt.Characteristic]
|
|
||||||
available_audio_contexts: gatt.Characteristic
|
|
||||||
supported_audio_contexts: gatt.Characteristic
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
supported_source_context: ContextType,
|
|
||||||
supported_sink_context: ContextType,
|
|
||||||
available_source_context: ContextType,
|
|
||||||
available_sink_context: ContextType,
|
|
||||||
sink_pac: Sequence[PacRecord] = (),
|
|
||||||
sink_audio_locations: Optional[AudioLocation] = None,
|
|
||||||
source_pac: Sequence[PacRecord] = (),
|
|
||||||
source_audio_locations: Optional[AudioLocation] = None,
|
|
||||||
) -> None:
|
|
||||||
characteristics = []
|
|
||||||
|
|
||||||
self.supported_audio_contexts = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=struct.pack('<HH', supported_sink_context, supported_source_context),
|
|
||||||
)
|
|
||||||
characteristics.append(self.supported_audio_contexts)
|
|
||||||
|
|
||||||
self.available_audio_contexts = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ
|
|
||||||
| gatt.Characteristic.Properties.NOTIFY,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=struct.pack('<HH', available_sink_context, available_source_context),
|
|
||||||
)
|
|
||||||
characteristics.append(self.available_audio_contexts)
|
|
||||||
|
|
||||||
if sink_pac:
|
|
||||||
self.sink_pac = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
|
|
||||||
)
|
|
||||||
characteristics.append(self.sink_pac)
|
|
||||||
|
|
||||||
if sink_audio_locations is not None:
|
|
||||||
self.sink_audio_locations = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=struct.pack('<I', sink_audio_locations),
|
|
||||||
)
|
|
||||||
characteristics.append(self.sink_audio_locations)
|
|
||||||
|
|
||||||
if source_pac:
|
|
||||||
self.source_pac = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
|
|
||||||
)
|
|
||||||
characteristics.append(self.source_pac)
|
|
||||||
|
|
||||||
if source_audio_locations is not None:
|
|
||||||
self.source_audio_locations = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.READ,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=struct.pack('<I', source_audio_locations),
|
|
||||||
)
|
|
||||||
characteristics.append(self.source_audio_locations)
|
|
||||||
|
|
||||||
super().__init__(characteristics)
|
|
||||||
|
|
||||||
|
|
||||||
class AseStateMachine(gatt.Characteristic):
|
|
||||||
class State(enum.IntEnum):
|
|
||||||
# fmt: off
|
|
||||||
IDLE = 0x00
|
|
||||||
CODEC_CONFIGURED = 0x01
|
|
||||||
QOS_CONFIGURED = 0x02
|
|
||||||
ENABLING = 0x03
|
|
||||||
STREAMING = 0x04
|
|
||||||
DISABLING = 0x05
|
|
||||||
RELEASING = 0x06
|
|
||||||
|
|
||||||
cis_link: Optional[device.CisLink] = None
|
|
||||||
|
|
||||||
# Additional parameters in CODEC_CONFIGURED State
|
|
||||||
preferred_framing = 0 # Unframed PDU supported
|
|
||||||
preferred_phy = 0
|
|
||||||
preferred_retransmission_number = 13
|
|
||||||
preferred_max_transport_latency = 100
|
|
||||||
supported_presentation_delay_min = 0
|
|
||||||
supported_presentation_delay_max = 0
|
|
||||||
preferred_presentation_delay_min = 0
|
|
||||||
preferred_presentation_delay_max = 0
|
|
||||||
codec_id = hci.CodingFormat(hci.CodecID.LC3)
|
|
||||||
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
|
|
||||||
|
|
||||||
# Additional parameters in QOS_CONFIGURED State
|
|
||||||
cig_id = 0
|
|
||||||
cis_id = 0
|
|
||||||
sdu_interval = 0
|
|
||||||
framing = 0
|
|
||||||
phy = 0
|
|
||||||
max_sdu = 0
|
|
||||||
retransmission_number = 0
|
|
||||||
max_transport_latency = 0
|
|
||||||
presentation_delay = 0
|
|
||||||
|
|
||||||
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
|
||||||
# TODO: Parse this
|
|
||||||
metadata = b''
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
role: AudioRole,
|
|
||||||
ase_id: int,
|
|
||||||
service: AudioStreamControlService,
|
|
||||||
) -> None:
|
|
||||||
self.service = service
|
|
||||||
self.ase_id = ase_id
|
|
||||||
self._state = AseStateMachine.State.IDLE
|
|
||||||
self.role = role
|
|
||||||
|
|
||||||
uuid = (
|
|
||||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
|
||||||
if role == AudioRole.SINK
|
|
||||||
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
uuid=uuid,
|
|
||||||
properties=gatt.Characteristic.Properties.READ
|
|
||||||
| gatt.Characteristic.Properties.NOTIFY,
|
|
||||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
|
||||||
value=gatt.CharacteristicValue(read=self.on_read),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.service.device.on('cis_request', self.on_cis_request)
|
|
||||||
self.service.device.on('cis_establishment', self.on_cis_establishment)
|
|
||||||
|
|
||||||
def on_cis_request(
|
|
||||||
self,
|
|
||||||
acl_connection: device.Connection,
|
|
||||||
cis_handle: int,
|
|
||||||
cig_id: int,
|
|
||||||
cis_id: int,
|
|
||||||
) -> None:
|
|
||||||
if (
|
|
||||||
cig_id == self.cig_id
|
|
||||||
and cis_id == self.cis_id
|
|
||||||
and self.state == self.State.ENABLING
|
|
||||||
):
|
|
||||||
acl_connection.abort_on(
|
|
||||||
'flush', self.service.device.accept_cis_request(cis_handle)
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
|
||||||
if (
|
|
||||||
cis_link.cig_id == self.cig_id
|
|
||||||
and cis_link.cis_id == self.cis_id
|
|
||||||
and self.state == self.State.ENABLING
|
|
||||||
):
|
|
||||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
|
||||||
|
|
||||||
async def post_cis_established():
|
|
||||||
await self.service.device.send_command(
|
|
||||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
|
||||||
connection_handle=cis_link.handle,
|
|
||||||
data_path_direction=self.role,
|
|
||||||
data_path_id=0x00, # Fixed HCI
|
|
||||||
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
|
||||||
controller_delay=0,
|
|
||||||
codec_configuration=b'',
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if self.role == AudioRole.SINK:
|
|
||||||
self.state = self.State.STREAMING
|
|
||||||
await self.service.device.notify_subscribers(self, self.value)
|
|
||||||
|
|
||||||
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
|
||||||
self.cis_link = cis_link
|
|
||||||
|
|
||||||
def on_cis_disconnection(self, _reason) -> None:
|
|
||||||
self.cis_link = None
|
|
||||||
|
|
||||||
def on_config_codec(
|
|
||||||
self,
|
|
||||||
target_latency: int,
|
|
||||||
target_phy: int,
|
|
||||||
codec_id: hci.CodingFormat,
|
|
||||||
codec_specific_configuration: bytes,
|
|
||||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state not in (
|
|
||||||
self.State.IDLE,
|
|
||||||
self.State.CODEC_CONFIGURED,
|
|
||||||
self.State.QOS_CONFIGURED,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.max_transport_latency = target_latency
|
|
||||||
self.phy = target_phy
|
|
||||||
self.codec_id = codec_id
|
|
||||||
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
|
||||||
self.codec_specific_configuration = codec_specific_configuration
|
|
||||||
else:
|
|
||||||
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
|
|
||||||
codec_specific_configuration
|
|
||||||
)
|
|
||||||
|
|
||||||
self.state = self.State.CODEC_CONFIGURED
|
|
||||||
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_config_qos(
|
|
||||||
self,
|
|
||||||
cig_id: int,
|
|
||||||
cis_id: int,
|
|
||||||
sdu_interval: int,
|
|
||||||
framing: int,
|
|
||||||
phy: int,
|
|
||||||
max_sdu: int,
|
|
||||||
retransmission_number: int,
|
|
||||||
max_transport_latency: int,
|
|
||||||
presentation_delay: int,
|
|
||||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state not in (
|
|
||||||
AseStateMachine.State.CODEC_CONFIGURED,
|
|
||||||
AseStateMachine.State.QOS_CONFIGURED,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.cig_id = cig_id
|
|
||||||
self.cis_id = cis_id
|
|
||||||
self.sdu_interval = sdu_interval
|
|
||||||
self.framing = framing
|
|
||||||
self.phy = phy
|
|
||||||
self.max_sdu = max_sdu
|
|
||||||
self.retransmission_number = retransmission_number
|
|
||||||
self.max_transport_latency = max_transport_latency
|
|
||||||
self.presentation_delay = presentation_delay
|
|
||||||
|
|
||||||
self.state = self.State.QOS_CONFIGURED
|
|
||||||
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state != AseStateMachine.State.QOS_CONFIGURED:
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.metadata = metadata
|
|
||||||
self.state = self.State.ENABLING
|
|
||||||
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state != AseStateMachine.State.ENABLING:
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
self.state = self.State.STREAMING
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state not in (
|
|
||||||
AseStateMachine.State.ENABLING,
|
|
||||||
AseStateMachine.State.STREAMING,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
if self.role == AudioRole.SINK:
|
|
||||||
self.state = self.State.QOS_CONFIGURED
|
|
||||||
else:
|
|
||||||
self.state = self.State.DISABLING
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if (
|
|
||||||
self.role != AudioRole.SOURCE
|
|
||||||
or self.state != AseStateMachine.State.DISABLING
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
self.state = self.State.QOS_CONFIGURED
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_update_metadata(
|
|
||||||
self, metadata: bytes
|
|
||||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state not in (
|
|
||||||
AseStateMachine.State.ENABLING,
|
|
||||||
AseStateMachine.State.STREAMING,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
self.metadata = metadata
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
|
||||||
if self.state == AseStateMachine.State.IDLE:
|
|
||||||
return (
|
|
||||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
|
||||||
AseReasonCode.NONE,
|
|
||||||
)
|
|
||||||
self.state = self.State.RELEASING
|
|
||||||
|
|
||||||
async def remove_cis_async():
|
|
||||||
await self.service.device.send_command(
|
|
||||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
|
||||||
connection_handle=self.cis_link.handle,
|
|
||||||
data_path_direction=self.role,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.state = self.State.IDLE
|
|
||||||
await self.service.device.notify_subscribers(self, self.value)
|
|
||||||
|
|
||||||
self.service.device.abort_on('flush', remove_cis_async())
|
|
||||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> State:
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
@state.setter
|
|
||||||
def state(self, new_state: State) -> None:
|
|
||||||
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
|
||||||
self._state = new_state
|
|
||||||
self.emit('state_change')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def value(self):
|
|
||||||
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
|
|
||||||
|
|
||||||
if self.state == self.State.CODEC_CONFIGURED:
|
|
||||||
codec_specific_configuration_bytes = bytes(
|
|
||||||
self.codec_specific_configuration
|
|
||||||
)
|
|
||||||
additional_parameters = (
|
|
||||||
struct.pack(
|
|
||||||
'<BBBH',
|
|
||||||
self.preferred_framing,
|
|
||||||
self.preferred_phy,
|
|
||||||
self.preferred_retransmission_number,
|
|
||||||
self.preferred_max_transport_latency,
|
|
||||||
)
|
|
||||||
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
|
|
||||||
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
|
|
||||||
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
|
|
||||||
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
|
|
||||||
+ bytes(self.codec_id)
|
|
||||||
+ bytes([len(codec_specific_configuration_bytes)])
|
|
||||||
+ codec_specific_configuration_bytes
|
|
||||||
)
|
|
||||||
elif self.state == self.State.QOS_CONFIGURED:
|
|
||||||
additional_parameters = (
|
|
||||||
bytes([self.cig_id, self.cis_id])
|
|
||||||
+ self.sdu_interval.to_bytes(3, 'little')
|
|
||||||
+ struct.pack(
|
|
||||||
'<BBHBH',
|
|
||||||
self.framing,
|
|
||||||
self.phy,
|
|
||||||
self.max_sdu,
|
|
||||||
self.retransmission_number,
|
|
||||||
self.max_transport_latency,
|
|
||||||
)
|
|
||||||
+ self.presentation_delay.to_bytes(3, 'little')
|
|
||||||
)
|
|
||||||
elif self.state in (
|
|
||||||
self.State.ENABLING,
|
|
||||||
self.State.STREAMING,
|
|
||||||
self.State.DISABLING,
|
|
||||||
):
|
|
||||||
additional_parameters = (
|
|
||||||
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
additional_parameters = b''
|
|
||||||
|
|
||||||
return bytes([self.ase_id, self.state]) + additional_parameters
|
|
||||||
|
|
||||||
@value.setter
|
|
||||||
def value(self, _new_value):
|
|
||||||
# Readonly. Do nothing in the setter.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_read(self, _: Optional[device.Connection]) -> bytes:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return (
|
|
||||||
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
|
|
||||||
f'state={self._state.name})'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AudioStreamControlService(gatt.TemplateService):
|
|
||||||
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
|
|
||||||
|
|
||||||
ase_state_machines: Dict[int, AseStateMachine]
|
|
||||||
ase_control_point: gatt.Characteristic
|
|
||||||
_active_client: Optional[device.Connection] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: device.Device,
|
|
||||||
source_ase_id: Sequence[int] = [],
|
|
||||||
sink_ase_id: Sequence[int] = [],
|
|
||||||
) -> None:
|
|
||||||
self.device = device
|
|
||||||
self.ase_state_machines = {
|
|
||||||
**{
|
|
||||||
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
|
|
||||||
for id in sink_ase_id
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
|
|
||||||
for id in source_ase_id
|
|
||||||
},
|
|
||||||
} # ASE state machines, by ASE ID
|
|
||||||
|
|
||||||
self.ase_control_point = gatt.Characteristic(
|
|
||||||
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
|
|
||||||
properties=gatt.Characteristic.Properties.WRITE
|
|
||||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
|
||||||
| gatt.Characteristic.Properties.NOTIFY,
|
|
||||||
permissions=gatt.Characteristic.Permissions.WRITEABLE,
|
|
||||||
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
|
|
||||||
|
|
||||||
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
|
|
||||||
if ase := self.ase_state_machines.get(ase_id):
|
|
||||||
handler = getattr(ase, 'on_' + opcode.name.lower())
|
|
||||||
return (ase_id, *handler(*args))
|
|
||||||
else:
|
|
||||||
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
|
|
||||||
|
|
||||||
def _on_client_disconnected(self, _reason: int) -> None:
|
|
||||||
for ase in self.ase_state_machines.values():
|
|
||||||
ase.state = AseStateMachine.State.IDLE
|
|
||||||
self._active_client = None
|
|
||||||
|
|
||||||
def on_write_ase_control_point(self, connection, data):
|
|
||||||
if not self._active_client and connection:
|
|
||||||
self._active_client = connection
|
|
||||||
connection.once('disconnection', self._on_client_disconnected)
|
|
||||||
|
|
||||||
operation = ASE_Operation.from_bytes(data)
|
|
||||||
responses = []
|
|
||||||
logger.debug(f'*** ASCS Write {operation} ***')
|
|
||||||
|
|
||||||
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
|
|
||||||
for ase_id, *args in zip(
|
|
||||||
operation.ase_id,
|
|
||||||
operation.target_latency,
|
|
||||||
operation.target_phy,
|
|
||||||
operation.codec_id,
|
|
||||||
operation.codec_specific_configuration,
|
|
||||||
):
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
|
||||||
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
|
|
||||||
for ase_id, *args in zip(
|
|
||||||
operation.ase_id,
|
|
||||||
operation.cig_id,
|
|
||||||
operation.cis_id,
|
|
||||||
operation.sdu_interval,
|
|
||||||
operation.framing,
|
|
||||||
operation.phy,
|
|
||||||
operation.max_sdu,
|
|
||||||
operation.retransmission_number,
|
|
||||||
operation.max_transport_latency,
|
|
||||||
operation.presentation_delay,
|
|
||||||
):
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
|
||||||
elif operation.op_code in (
|
|
||||||
ASE_Operation.Opcode.ENABLE,
|
|
||||||
ASE_Operation.Opcode.UPDATE_METADATA,
|
|
||||||
):
|
|
||||||
for ase_id, *args in zip(
|
|
||||||
operation.ase_id,
|
|
||||||
operation.metadata,
|
|
||||||
):
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
|
||||||
elif operation.op_code in (
|
|
||||||
ASE_Operation.Opcode.RECEIVER_START_READY,
|
|
||||||
ASE_Operation.Opcode.DISABLE,
|
|
||||||
ASE_Operation.Opcode.RECEIVER_STOP_READY,
|
|
||||||
ASE_Operation.Opcode.RELEASE,
|
|
||||||
):
|
|
||||||
for ase_id in operation.ase_id:
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
|
||||||
|
|
||||||
control_point_notification = bytes(
|
|
||||||
[operation.op_code, len(responses)]
|
|
||||||
) + b''.join(map(bytes, responses))
|
|
||||||
self.device.abort_on(
|
|
||||||
'flush',
|
|
||||||
self.device.notify_subscribers(
|
|
||||||
self.ase_control_point, control_point_notification
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
for ase_id, *_ in responses:
|
|
||||||
if ase := self.ase_state_machines.get(ase_id):
|
|
||||||
self.device.abort_on(
|
|
||||||
'flush',
|
|
||||||
self.device.notify_subscribers(ase, ase.value),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Client
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
|
||||||
SERVICE_CLASS = PublishedAudioCapabilitiesService
|
|
||||||
|
|
||||||
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
|
|
||||||
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
|
||||||
source_pac: Optional[gatt_client.CharacteristicProxy] = None
|
|
||||||
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
|
||||||
available_audio_contexts: gatt_client.CharacteristicProxy
|
|
||||||
supported_audio_contexts: gatt_client.CharacteristicProxy
|
|
||||||
|
|
||||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
|
||||||
self.service_proxy = service_proxy
|
|
||||||
|
|
||||||
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
|
||||||
)[0]
|
|
||||||
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SINK_PAC_CHARACTERISTIC
|
|
||||||
):
|
|
||||||
self.sink_pac = characteristics[0]
|
|
||||||
|
|
||||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
|
|
||||||
):
|
|
||||||
self.source_pac = characteristics[0]
|
|
||||||
|
|
||||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
|
|
||||||
):
|
|
||||||
self.sink_audio_locations = characteristics[0]
|
|
||||||
|
|
||||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
|
||||||
):
|
|
||||||
self.source_audio_locations = characteristics[0]
|
|
||||||
|
|
||||||
|
|
||||||
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
|
|
||||||
SERVICE_CLASS = AudioStreamControlService
|
|
||||||
|
|
||||||
sink_ase: List[gatt_client.CharacteristicProxy]
|
|
||||||
source_ase: List[gatt_client.CharacteristicProxy]
|
|
||||||
ase_control_point: gatt_client.CharacteristicProxy
|
|
||||||
|
|
||||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
|
||||||
self.service_proxy = service_proxy
|
|
||||||
|
|
||||||
self.sink_ase = service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
|
||||||
)
|
|
||||||
self.source_ase = service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
|
||||||
)
|
|
||||||
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
|
|
||||||
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
|
|
||||||
)[0]
|
|
||||||
|
|||||||
440
bumble/profiles/bass.py
Normal file
440
bumble/profiles/bass.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
# Copyright 2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for
|
||||||
|
|
||||||
|
"""LE Audio - Broadcast Audio Scan Service"""
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from typing import ClassVar, List, Optional, Sequence
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
|
from bumble import device
|
||||||
|
from bumble import gatt
|
||||||
|
from bumble import gatt_client
|
||||||
|
from bumble import hci
|
||||||
|
from bumble import utils
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class ApplicationError(utils.OpenIntEnum):
|
||||||
|
OPCODE_NOT_SUPPORTED = 0x80
|
||||||
|
INVALID_SOURCE_ID = 0x81
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
|
||||||
|
return bytes([len(subgroups)]) + b"".join(
|
||||||
|
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
|
||||||
|
+ subgroup.metadata
|
||||||
|
for subgroup in subgroups
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
|
||||||
|
num_subgroups = data[0]
|
||||||
|
offset = 1
|
||||||
|
subgroups = []
|
||||||
|
for _ in range(num_subgroups):
|
||||||
|
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
|
||||||
|
metadata_length = data[offset + 4]
|
||||||
|
metadata = data[offset + 5 : offset + 5 + metadata_length]
|
||||||
|
offset += 5 + metadata_length
|
||||||
|
subgroups.append(SubgroupInfo(bis_sync, metadata))
|
||||||
|
|
||||||
|
return subgroups
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
|
||||||
|
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
|
||||||
|
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
|
||||||
|
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SubgroupInfo:
|
||||||
|
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
|
||||||
|
|
||||||
|
bis_sync: int
|
||||||
|
metadata: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class ControlPointOperation:
|
||||||
|
class OpCode(utils.OpenIntEnum):
|
||||||
|
REMOTE_SCAN_STOPPED = 0x00
|
||||||
|
REMOTE_SCAN_STARTED = 0x01
|
||||||
|
ADD_SOURCE = 0x02
|
||||||
|
MODIFY_SOURCE = 0x03
|
||||||
|
SET_BROADCAST_CODE = 0x04
|
||||||
|
REMOVE_SOURCE = 0x05
|
||||||
|
|
||||||
|
op_code: OpCode
|
||||||
|
parameters: bytes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> ControlPointOperation:
|
||||||
|
op_code = data[0]
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
|
||||||
|
return RemoteScanStoppedOperation()
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
|
||||||
|
return RemoteScanStartedOperation()
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.ADD_SOURCE:
|
||||||
|
return AddSourceOperation.from_parameters(data[1:])
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.MODIFY_SOURCE:
|
||||||
|
return ModifySourceOperation.from_parameters(data[1:])
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.SET_BROADCAST_CODE:
|
||||||
|
return SetBroadcastCodeOperation.from_parameters(data[1:])
|
||||||
|
|
||||||
|
if op_code == cls.OpCode.REMOVE_SOURCE:
|
||||||
|
return RemoveSourceOperation.from_parameters(data[1:])
|
||||||
|
|
||||||
|
raise core.InvalidArgumentError("invalid op code")
|
||||||
|
|
||||||
|
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
|
||||||
|
self.op_code = op_code
|
||||||
|
self.parameters = parameters
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return bytes([self.op_code]) + self.parameters
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteScanStoppedOperation(ControlPointOperation):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteScanStartedOperation(ControlPointOperation):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
|
||||||
|
|
||||||
|
|
||||||
|
class AddSourceOperation(ControlPointOperation):
|
||||||
|
@classmethod
|
||||||
|
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
|
||||||
|
instance = cls.__new__(cls)
|
||||||
|
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
|
||||||
|
instance.parameters = parameters
|
||||||
|
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
|
||||||
|
parameters, 1
|
||||||
|
)[1]
|
||||||
|
instance.advertising_sid = parameters[7]
|
||||||
|
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
|
||||||
|
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
|
||||||
|
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
|
||||||
|
instance.subgroups = decode_subgroups(parameters[14:])
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
advertiser_address: hci.Address,
|
||||||
|
advertising_sid: int,
|
||||||
|
broadcast_id: int,
|
||||||
|
pa_sync: PeriodicAdvertisingSyncParams,
|
||||||
|
pa_interval: int,
|
||||||
|
subgroups: Sequence[SubgroupInfo],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
ControlPointOperation.OpCode.ADD_SOURCE,
|
||||||
|
struct.pack(
|
||||||
|
"<B6sB3sBH",
|
||||||
|
advertiser_address.address_type,
|
||||||
|
bytes(advertiser_address),
|
||||||
|
advertising_sid,
|
||||||
|
broadcast_id.to_bytes(3, "little"),
|
||||||
|
pa_sync,
|
||||||
|
pa_interval,
|
||||||
|
)
|
||||||
|
+ encode_subgroups(subgroups),
|
||||||
|
)
|
||||||
|
self.advertiser_address = advertiser_address
|
||||||
|
self.advertising_sid = advertising_sid
|
||||||
|
self.broadcast_id = broadcast_id
|
||||||
|
self.pa_sync = pa_sync
|
||||||
|
self.pa_interval = pa_interval
|
||||||
|
self.subgroups = list(subgroups)
|
||||||
|
|
||||||
|
|
||||||
|
class ModifySourceOperation(ControlPointOperation):
|
||||||
|
@classmethod
|
||||||
|
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
|
||||||
|
instance = cls.__new__(cls)
|
||||||
|
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
|
||||||
|
instance.parameters = parameters
|
||||||
|
instance.source_id = parameters[0]
|
||||||
|
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
|
||||||
|
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
|
||||||
|
instance.subgroups = decode_subgroups(parameters[4:])
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source_id: int,
|
||||||
|
pa_sync: PeriodicAdvertisingSyncParams,
|
||||||
|
pa_interval: int,
|
||||||
|
subgroups: Sequence[SubgroupInfo],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
ControlPointOperation.OpCode.MODIFY_SOURCE,
|
||||||
|
struct.pack("<BBH", source_id, pa_sync, pa_interval)
|
||||||
|
+ encode_subgroups(subgroups),
|
||||||
|
)
|
||||||
|
self.source_id = source_id
|
||||||
|
self.pa_sync = pa_sync
|
||||||
|
self.pa_interval = pa_interval
|
||||||
|
self.subgroups = list(subgroups)
|
||||||
|
|
||||||
|
|
||||||
|
class SetBroadcastCodeOperation(ControlPointOperation):
|
||||||
|
@classmethod
|
||||||
|
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
|
||||||
|
instance = cls.__new__(cls)
|
||||||
|
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
|
||||||
|
instance.parameters = parameters
|
||||||
|
instance.source_id = parameters[0]
|
||||||
|
instance.broadcast_code = parameters[1:17]
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source_id: int,
|
||||||
|
broadcast_code: bytes,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
|
||||||
|
bytes([source_id]) + broadcast_code,
|
||||||
|
)
|
||||||
|
self.source_id = source_id
|
||||||
|
self.broadcast_code = broadcast_code
|
||||||
|
|
||||||
|
if len(self.broadcast_code) != 16:
|
||||||
|
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveSourceOperation(ControlPointOperation):
|
||||||
|
@classmethod
|
||||||
|
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
|
||||||
|
instance = cls.__new__(cls)
|
||||||
|
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
|
||||||
|
instance.parameters = parameters
|
||||||
|
instance.source_id = parameters[0]
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __init__(self, source_id: int) -> None:
|
||||||
|
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
|
||||||
|
self.source_id = source_id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BroadcastReceiveState:
|
||||||
|
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
|
||||||
|
NOT_SYNCHRONIZED_TO_PA = 0x00
|
||||||
|
SYNCINFO_REQUEST = 0x01
|
||||||
|
SYNCHRONIZED_TO_PA = 0x02
|
||||||
|
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
|
||||||
|
NO_PAST = 0x04
|
||||||
|
|
||||||
|
class BigEncryption(utils.OpenIntEnum):
|
||||||
|
NOT_ENCRYPTED = 0x00
|
||||||
|
BROADCAST_CODE_REQUIRED = 0x01
|
||||||
|
DECRYPTING = 0x02
|
||||||
|
BAD_CODE = 0x03
|
||||||
|
|
||||||
|
source_id: int
|
||||||
|
source_address: hci.Address
|
||||||
|
source_adv_sid: int
|
||||||
|
broadcast_id: int
|
||||||
|
pa_sync_state: PeriodicAdvertisingSyncState
|
||||||
|
big_encryption: BigEncryption
|
||||||
|
bad_code: bytes
|
||||||
|
subgroups: List[SubgroupInfo]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
source_id = data[0]
|
||||||
|
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
|
||||||
|
source_adv_sid = data[8]
|
||||||
|
broadcast_id = int.from_bytes(data[9:12], "little")
|
||||||
|
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
|
||||||
|
big_encryption = cls.BigEncryption(data[13])
|
||||||
|
if big_encryption == cls.BigEncryption.BAD_CODE:
|
||||||
|
bad_code = data[14:30]
|
||||||
|
subgroups = decode_subgroups(data[30:])
|
||||||
|
else:
|
||||||
|
bad_code = b""
|
||||||
|
subgroups = decode_subgroups(data[14:])
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
source_id,
|
||||||
|
source_address,
|
||||||
|
source_adv_sid,
|
||||||
|
broadcast_id,
|
||||||
|
pa_sync_state,
|
||||||
|
big_encryption,
|
||||||
|
bad_code,
|
||||||
|
subgroups,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return (
|
||||||
|
struct.pack(
|
||||||
|
"<BB6sB3sBB",
|
||||||
|
self.source_id,
|
||||||
|
self.source_address.address_type,
|
||||||
|
bytes(self.source_address),
|
||||||
|
self.source_adv_sid,
|
||||||
|
self.broadcast_id.to_bytes(3, "little"),
|
||||||
|
self.pa_sync_state,
|
||||||
|
self.big_encryption,
|
||||||
|
)
|
||||||
|
+ self.bad_code
|
||||||
|
+ encode_subgroups(self.subgroups)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class BroadcastAudioScanService(gatt.TemplateService):
|
||||||
|
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
|
||||||
|
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
|
||||||
|
gatt.Characteristic.Properties.WRITE
|
||||||
|
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||||
|
gatt.Characteristic.WRITEABLE,
|
||||||
|
gatt.CharacteristicValue(
|
||||||
|
write=self.on_broadcast_audio_scan_control_point_write
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.broadcast_receive_state_characteristic = gatt.Characteristic(
|
||||||
|
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
|
||||||
|
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
gatt.Characteristic.Permissions.READABLE
|
||||||
|
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
b"12", # TEST
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__([self.battery_level_characteristic])
|
||||||
|
|
||||||
|
def on_broadcast_audio_scan_control_point_write(
|
||||||
|
self, connection: device.Connection, value: bytes
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
|
||||||
|
SERVICE_CLASS = BroadcastAudioScanService
|
||||||
|
|
||||||
|
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
|
||||||
|
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
|
||||||
|
if not (
|
||||||
|
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise gatt.InvalidServiceError(
|
||||||
|
"Broadcast Audio Scan Control Point characteristic not found"
|
||||||
|
)
|
||||||
|
self.broadcast_audio_scan_control_point = characteristics[0]
|
||||||
|
|
||||||
|
if not (
|
||||||
|
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise gatt.InvalidServiceError(
|
||||||
|
"Broadcast Receive State characteristic not found"
|
||||||
|
)
|
||||||
|
self.broadcast_receive_states = [
|
||||||
|
gatt.DelegatedCharacteristicAdapter(
|
||||||
|
characteristic, decode=BroadcastReceiveState.from_bytes
|
||||||
|
)
|
||||||
|
for characteristic in characteristics
|
||||||
|
]
|
||||||
|
|
||||||
|
async def send_control_point_operation(
|
||||||
|
self, operation: ControlPointOperation
|
||||||
|
) -> None:
|
||||||
|
await self.broadcast_audio_scan_control_point.write_value(
|
||||||
|
bytes(operation), with_response=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remote_scan_started(self) -> None:
|
||||||
|
await self.send_control_point_operation(RemoteScanStartedOperation())
|
||||||
|
|
||||||
|
async def remote_scan_stopped(self) -> None:
|
||||||
|
await self.send_control_point_operation(RemoteScanStoppedOperation())
|
||||||
|
|
||||||
|
async def add_source(
|
||||||
|
self,
|
||||||
|
advertiser_address: hci.Address,
|
||||||
|
advertising_sid: int,
|
||||||
|
broadcast_id: int,
|
||||||
|
pa_sync: PeriodicAdvertisingSyncParams,
|
||||||
|
pa_interval: int,
|
||||||
|
subgroups: Sequence[SubgroupInfo],
|
||||||
|
) -> None:
|
||||||
|
await self.send_control_point_operation(
|
||||||
|
AddSourceOperation(
|
||||||
|
advertiser_address,
|
||||||
|
advertising_sid,
|
||||||
|
broadcast_id,
|
||||||
|
pa_sync,
|
||||||
|
pa_interval,
|
||||||
|
subgroups,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def modify_source(
|
||||||
|
self,
|
||||||
|
source_id: int,
|
||||||
|
pa_sync: PeriodicAdvertisingSyncParams,
|
||||||
|
pa_interval: int,
|
||||||
|
subgroups: Sequence[SubgroupInfo],
|
||||||
|
) -> None:
|
||||||
|
await self.send_control_point_operation(
|
||||||
|
ModifySourceOperation(
|
||||||
|
source_id,
|
||||||
|
pa_sync,
|
||||||
|
pa_interval,
|
||||||
|
subgroups,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remove_source(self, source_id: int) -> None:
|
||||||
|
await self.send_control_point_operation(RemoveSourceOperation(source_id))
|
||||||
@@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
|||||||
set_member_rank: Optional[int] = None,
|
set_member_rank: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||||
raise ValueError(
|
raise core.InvalidArgumentError(
|
||||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
|||||||
key = await connection.device.get_link_key(connection.peer_address)
|
key = await connection.device.get_link_key(connection.peer_address)
|
||||||
|
|
||||||
if not key:
|
if not key:
|
||||||
raise RuntimeError('LTK or LinkKey is not present')
|
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||||
|
|
||||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
|||||||
'''Reads SIRK and decrypts if encrypted.'''
|
'''Reads SIRK and decrypts if encrypted.'''
|
||||||
response = await self.set_identity_resolving_key.read_value()
|
response = await self.set_identity_resolving_key.read_value()
|
||||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||||
raise RuntimeError('Invalid SIRK value')
|
raise core.InvalidPacketError('Invalid SIRK value')
|
||||||
|
|
||||||
sirk_type = SirkType(response[0])
|
sirk_type = SirkType(response[0])
|
||||||
if sirk_type == SirkType.PLAINTEXT:
|
if sirk_type == SirkType.PLAINTEXT:
|
||||||
@@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
|||||||
key = await device.get_link_key(connection.peer_address)
|
key = await device.get_link_key(connection.peer_address)
|
||||||
|
|
||||||
if not key:
|
if not key:
|
||||||
raise RuntimeError('LTK or LinkKey is not present')
|
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||||
|
|
||||||
sirk = sef(key, response[1:])
|
sirk = sef(key, response[1:])
|
||||||
|
|
||||||
|
|||||||
110
bumble/profiles/gap.py
Normal file
110
bumble/profiles/gap.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# Copyright 2021-2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Generic Access Profile"""
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
from bumble.core import Appearance
|
||||||
|
from bumble.gatt import (
|
||||||
|
TemplateService,
|
||||||
|
Characteristic,
|
||||||
|
CharacteristicAdapter,
|
||||||
|
DelegatedCharacteristicAdapter,
|
||||||
|
UTF8CharacteristicAdapter,
|
||||||
|
GATT_GENERIC_ACCESS_SERVICE,
|
||||||
|
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||||
|
GATT_APPEARANCE_CHARACTERISTIC,
|
||||||
|
)
|
||||||
|
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Classes
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class GenericAccessService(TemplateService):
|
||||||
|
UUID = GATT_GENERIC_ACCESS_SERVICE
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
|
||||||
|
):
|
||||||
|
if isinstance(appearance, int):
|
||||||
|
appearance_int = appearance
|
||||||
|
elif isinstance(appearance, tuple):
|
||||||
|
appearance_int = (appearance[0] << 6) | appearance[1]
|
||||||
|
elif isinstance(appearance, Appearance):
|
||||||
|
appearance_int = int(appearance)
|
||||||
|
else:
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
|
self.device_name_characteristic = Characteristic(
|
||||||
|
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||||
|
Characteristic.Properties.READ,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
device_name.encode('utf-8')[:248],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.appearance_characteristic = Characteristic(
|
||||||
|
GATT_APPEARANCE_CHARACTERISTIC,
|
||||||
|
Characteristic.Properties.READ,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
struct.pack('<H', appearance_int),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
[self.device_name_characteristic, self.appearance_characteristic]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class GenericAccessServiceProxy(ProfileServiceProxy):
|
||||||
|
SERVICE_CLASS = GenericAccessService
|
||||||
|
|
||||||
|
device_name: Optional[CharacteristicAdapter]
|
||||||
|
appearance: Optional[DelegatedCharacteristicAdapter]
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: ServiceProxy):
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
GATT_DEVICE_NAME_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
|
||||||
|
else:
|
||||||
|
self.device_name = None
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
GATT_APPEARANCE_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.appearance = DelegatedCharacteristicAdapter(
|
||||||
|
characteristics[0],
|
||||||
|
decode=lambda value: Appearance.from_int(
|
||||||
|
struct.unpack_from('<H', value, 0)[0],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.appearance = None
|
||||||
@@ -19,6 +19,7 @@
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
from ..gatt_client import ProfileServiceProxy
|
from ..gatt_client import ProfileServiceProxy
|
||||||
from ..att import ATT_Error
|
from ..att import ATT_Error
|
||||||
from ..gatt import (
|
from ..gatt import (
|
||||||
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
|
|||||||
rr_intervals=None,
|
rr_intervals=None,
|
||||||
):
|
):
|
||||||
if heart_rate < 0 or heart_rate > 0xFFFF:
|
if heart_rate < 0 or heart_rate > 0xFFFF:
|
||||||
raise ValueError('heart_rate out of range')
|
raise core.InvalidArgumentError('heart_rate out of range')
|
||||||
|
|
||||||
if energy_expended is not None and (
|
if energy_expended is not None and (
|
||||||
energy_expended < 0 or energy_expended > 0xFFFF
|
energy_expended < 0 or energy_expended > 0xFFFF
|
||||||
):
|
):
|
||||||
raise ValueError('energy_expended out of range')
|
raise core.InvalidArgumentError('energy_expended out of range')
|
||||||
|
|
||||||
if rr_intervals:
|
if rr_intervals:
|
||||||
for rr_interval in rr_intervals:
|
for rr_interval in rr_intervals:
|
||||||
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
|
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
|
||||||
raise ValueError('rr_intervals out of range')
|
raise core.InvalidArgumentError('rr_intervals out of range')
|
||||||
|
|
||||||
self.heart_rate = heart_rate
|
self.heart_rate = heart_rate
|
||||||
self.sensor_contact_detected = sensor_contact_detected
|
self.sensor_contact_detected = sensor_contact_detected
|
||||||
|
|||||||
@@ -17,33 +17,67 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List
|
import struct
|
||||||
|
from typing import List, Type
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from bumble import utils
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Classes
|
# Classes
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Metadata:
|
class Metadata:
|
||||||
|
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
|
||||||
|
|
||||||
|
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
|
||||||
|
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
|
||||||
|
again outside the lib.
|
||||||
|
'''
|
||||||
|
|
||||||
|
class Tag(utils.OpenIntEnum):
|
||||||
|
# fmt: off
|
||||||
|
PREFERRED_AUDIO_CONTEXTS = 0x01
|
||||||
|
STREAMING_AUDIO_CONTEXTS = 0x02
|
||||||
|
PROGRAM_INFO = 0x03
|
||||||
|
LANGUAGE = 0x04
|
||||||
|
CCID_LIST = 0x05
|
||||||
|
PARENTAL_RATING = 0x06
|
||||||
|
PROGRAM_INFO_URI = 0x07
|
||||||
|
AUDIO_ACTIVE_STATE = 0x08
|
||||||
|
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
|
||||||
|
ASSISTED_LISTENING_STREAM = 0x0A
|
||||||
|
BROADCAST_NAME = 0x0B
|
||||||
|
EXTENDED_METADATA = 0xFE
|
||||||
|
VENDOR_SPECIFIC = 0xFF
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Entry:
|
class Entry:
|
||||||
tag: int
|
tag: Metadata.Tag
|
||||||
data: bytes
|
data: bytes
|
||||||
|
|
||||||
entries: List[Entry]
|
@classmethod
|
||||||
|
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||||
|
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return bytes([len(self.data) + 1, self.tag]) + self.data
|
||||||
|
|
||||||
|
entries: List[Entry] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, data: bytes) -> Self:
|
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||||
entries = []
|
entries = []
|
||||||
offset = 0
|
offset = 0
|
||||||
length = len(data)
|
length = len(data)
|
||||||
while length >= 2:
|
while offset < length:
|
||||||
entry_length = data[offset]
|
entry_length = data[offset]
|
||||||
entry_tag = data[offset + 1]
|
offset += 1
|
||||||
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
|
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
|
||||||
entries.append(cls.Entry(entry_tag, entry_data))
|
|
||||||
length -= entry_length
|
|
||||||
offset += entry_length
|
offset += entry_length
|
||||||
|
|
||||||
return cls(entries)
|
return cls(entries)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return b''.join([bytes(entry) for entry in self.entries])
|
||||||
|
|||||||
448
bumble/profiles/mcp.py
Normal file
448
bumble/profiles/mcp.py
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
# Copyright 2021-2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
|
from bumble import device
|
||||||
|
from bumble import gatt
|
||||||
|
from bumble import gatt_client
|
||||||
|
from bumble import utils
|
||||||
|
|
||||||
|
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PlayingOrder(utils.OpenIntEnum):
|
||||||
|
'''See Media Control Service 3.15. Playing Order.'''
|
||||||
|
|
||||||
|
SINGLE_ONCE = 0x01
|
||||||
|
SINGLE_REPEAT = 0x02
|
||||||
|
IN_ORDER_ONCE = 0x03
|
||||||
|
IN_ORDER_REPEAT = 0x04
|
||||||
|
OLDEST_ONCE = 0x05
|
||||||
|
OLDEST_REPEAT = 0x06
|
||||||
|
NEWEST_ONCE = 0x07
|
||||||
|
NEWEST_REPEAT = 0x08
|
||||||
|
SHUFFLE_ONCE = 0x09
|
||||||
|
SHUFFLE_REPEAT = 0x0A
|
||||||
|
|
||||||
|
|
||||||
|
class PlayingOrderSupported(enum.IntFlag):
|
||||||
|
'''See Media Control Service 3.16. Playing Orders Supported.'''
|
||||||
|
|
||||||
|
SINGLE_ONCE = 0x0001
|
||||||
|
SINGLE_REPEAT = 0x0002
|
||||||
|
IN_ORDER_ONCE = 0x0004
|
||||||
|
IN_ORDER_REPEAT = 0x0008
|
||||||
|
OLDEST_ONCE = 0x0010
|
||||||
|
OLDEST_REPEAT = 0x0020
|
||||||
|
NEWEST_ONCE = 0x0040
|
||||||
|
NEWEST_REPEAT = 0x0080
|
||||||
|
SHUFFLE_ONCE = 0x0100
|
||||||
|
SHUFFLE_REPEAT = 0x0200
|
||||||
|
|
||||||
|
|
||||||
|
class MediaState(utils.OpenIntEnum):
|
||||||
|
'''See Media Control Service 3.17. Media State.'''
|
||||||
|
|
||||||
|
INACTIVE = 0x00
|
||||||
|
PLAYING = 0x01
|
||||||
|
PAUSED = 0x02
|
||||||
|
SEEKING = 0x03
|
||||||
|
|
||||||
|
|
||||||
|
class MediaControlPointOpcode(utils.OpenIntEnum):
|
||||||
|
'''See Media Control Service 3.18. Media Control Point.'''
|
||||||
|
|
||||||
|
PLAY = 0x01
|
||||||
|
PAUSE = 0x02
|
||||||
|
FAST_REWIND = 0x03
|
||||||
|
FAST_FORWARD = 0x04
|
||||||
|
STOP = 0x05
|
||||||
|
MOVE_RELATIVE = 0x10
|
||||||
|
PREVIOUS_SEGMENT = 0x20
|
||||||
|
NEXT_SEGMENT = 0x21
|
||||||
|
FIRST_SEGMENT = 0x22
|
||||||
|
LAST_SEGMENT = 0x23
|
||||||
|
GOTO_SEGMENT = 0x24
|
||||||
|
PREVIOUS_TRACK = 0x30
|
||||||
|
NEXT_TRACK = 0x31
|
||||||
|
FIRST_TRACK = 0x32
|
||||||
|
LAST_TRACK = 0x33
|
||||||
|
GOTO_TRACK = 0x34
|
||||||
|
PREVIOUS_GROUP = 0x40
|
||||||
|
NEXT_GROUP = 0x41
|
||||||
|
FIRST_GROUP = 0x42
|
||||||
|
LAST_GROUP = 0x43
|
||||||
|
GOTO_GROUP = 0x44
|
||||||
|
|
||||||
|
|
||||||
|
class MediaControlPointResultCode(enum.IntFlag):
|
||||||
|
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
|
||||||
|
|
||||||
|
SUCCESS = 0x01
|
||||||
|
OPCODE_NOT_SUPPORTED = 0x02
|
||||||
|
MEDIA_PLAYER_INACTIVE = 0x03
|
||||||
|
COMMAND_CANNOT_BE_COMPLETED = 0x04
|
||||||
|
|
||||||
|
|
||||||
|
class MediaControlPointOpcodeSupported(enum.IntFlag):
|
||||||
|
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
|
||||||
|
|
||||||
|
PLAY = 0x00000001
|
||||||
|
PAUSE = 0x00000002
|
||||||
|
FAST_REWIND = 0x00000004
|
||||||
|
FAST_FORWARD = 0x00000008
|
||||||
|
STOP = 0x00000010
|
||||||
|
MOVE_RELATIVE = 0x00000020
|
||||||
|
PREVIOUS_SEGMENT = 0x00000040
|
||||||
|
NEXT_SEGMENT = 0x00000080
|
||||||
|
FIRST_SEGMENT = 0x00000100
|
||||||
|
LAST_SEGMENT = 0x00000200
|
||||||
|
GOTO_SEGMENT = 0x00000400
|
||||||
|
PREVIOUS_TRACK = 0x00000800
|
||||||
|
NEXT_TRACK = 0x00001000
|
||||||
|
FIRST_TRACK = 0x00002000
|
||||||
|
LAST_TRACK = 0x00004000
|
||||||
|
GOTO_TRACK = 0x00008000
|
||||||
|
PREVIOUS_GROUP = 0x00010000
|
||||||
|
NEXT_GROUP = 0x00020000
|
||||||
|
FIRST_GROUP = 0x00040000
|
||||||
|
LAST_GROUP = 0x00080000
|
||||||
|
GOTO_GROUP = 0x00100000
|
||||||
|
|
||||||
|
|
||||||
|
class SearchControlPointItemType(utils.OpenIntEnum):
|
||||||
|
'''See Media Control Service 3.20. Search Control Point.'''
|
||||||
|
|
||||||
|
TRACK_NAME = 0x01
|
||||||
|
ARTIST_NAME = 0x02
|
||||||
|
ALBUM_NAME = 0x03
|
||||||
|
GROUP_NAME = 0x04
|
||||||
|
EARLIEST_YEAR = 0x05
|
||||||
|
LATEST_YEAR = 0x06
|
||||||
|
GENRE = 0x07
|
||||||
|
ONLY_TRACKS = 0x08
|
||||||
|
ONLY_GROUPS = 0x09
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectType(utils.OpenIntEnum):
|
||||||
|
'''See Media Control Service 4.4.1. Object Type field.'''
|
||||||
|
|
||||||
|
TASK = 0
|
||||||
|
GROUP = 1
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Classes
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectId(int):
|
||||||
|
'''See Media Control Service 4.4.2. Object ID field.'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||||
|
return cls(int.from_bytes(data, byteorder='little', signed=False))
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return self.to_bytes(6, 'little')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GroupObjectType:
|
||||||
|
'''See Media Control Service 4.4. Group Object Type.'''
|
||||||
|
|
||||||
|
object_type: ObjectType
|
||||||
|
object_id: ObjectId
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||||
|
return cls(
|
||||||
|
object_type=ObjectType(data[0]),
|
||||||
|
object_id=ObjectId.create_from_bytes(data[1:]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
return bytes([self.object_type]) + bytes(self.object_id)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Server
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class MediaControlService(gatt.TemplateService):
|
||||||
|
'''Media Control Service server implementation, only for testing currently.'''
|
||||||
|
|
||||||
|
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
|
||||||
|
|
||||||
|
def __init__(self, media_player_name: Optional[str] = None) -> None:
|
||||||
|
self.track_position = 0
|
||||||
|
|
||||||
|
self.media_player_name_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=media_player_name or 'Bumble Player',
|
||||||
|
)
|
||||||
|
self.track_changed_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.track_title_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.track_duration_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.track_position_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.WRITE
|
||||||
|
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||||
|
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.media_state_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.media_control_point_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.WRITE
|
||||||
|
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||||
|
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||||
|
value=gatt.CharacteristicValue(write=self.on_media_control_point),
|
||||||
|
)
|
||||||
|
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
self.content_control_id_characteristic = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||||
|
value=b'',
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
[
|
||||||
|
self.media_player_name_characteristic,
|
||||||
|
self.track_changed_characteristic,
|
||||||
|
self.track_title_characteristic,
|
||||||
|
self.track_duration_characteristic,
|
||||||
|
self.track_position_characteristic,
|
||||||
|
self.media_state_characteristic,
|
||||||
|
self.media_control_point_characteristic,
|
||||||
|
self.media_control_point_opcodes_supported_characteristic,
|
||||||
|
self.content_control_id_characteristic,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_media_control_point(
|
||||||
|
self, connection: Optional[device.Connection], data: bytes
|
||||||
|
) -> None:
|
||||||
|
if not connection:
|
||||||
|
raise core.InvalidStateError()
|
||||||
|
|
||||||
|
opcode = MediaControlPointOpcode(data[0])
|
||||||
|
|
||||||
|
await connection.device.notify_subscriber(
|
||||||
|
connection,
|
||||||
|
self.media_control_point_characteristic,
|
||||||
|
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericMediaControlService(MediaControlService):
|
||||||
|
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Client
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class MediaControlServiceProxy(
|
||||||
|
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
|
||||||
|
):
|
||||||
|
SERVICE_CLASS = MediaControlService
|
||||||
|
|
||||||
|
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
|
||||||
|
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||||
|
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
|
||||||
|
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||||
|
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||||
|
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||||
|
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||||
|
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
|
||||||
|
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
|
||||||
|
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
|
||||||
|
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
|
||||||
|
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||||
|
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||||
|
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||||
|
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
|
||||||
|
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
|
||||||
|
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||||
|
}
|
||||||
|
|
||||||
|
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
track_changed: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
track_title: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
track_duration: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
track_position: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
playing_order: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
media_state: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
media_control_point_notifications: asyncio.Queue[bytes]
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||||
|
utils.CompositeEventEmitter.__init__(self)
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.media_control_point_notifications = asyncio.Queue()
|
||||||
|
|
||||||
|
for field, uuid in self._CHARACTERISTICS.items():
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
|
||||||
|
setattr(self, field, characteristics[0])
|
||||||
|
|
||||||
|
async def subscribe_characteristics(self) -> None:
|
||||||
|
if self.media_control_point:
|
||||||
|
await self.media_control_point.subscribe(self._on_media_control_point)
|
||||||
|
if self.media_state:
|
||||||
|
await self.media_state.subscribe(self._on_media_state)
|
||||||
|
if self.track_changed:
|
||||||
|
await self.track_changed.subscribe(self._on_track_changed)
|
||||||
|
if self.track_title:
|
||||||
|
await self.track_title.subscribe(self._on_track_title)
|
||||||
|
if self.track_duration:
|
||||||
|
await self.track_duration.subscribe(self._on_track_duration)
|
||||||
|
if self.track_position:
|
||||||
|
await self.track_position.subscribe(self._on_track_position)
|
||||||
|
|
||||||
|
async def write_control_point(
|
||||||
|
self, opcode: MediaControlPointOpcode
|
||||||
|
) -> MediaControlPointResultCode:
|
||||||
|
'''Writes a Media Control Point Opcode to peer and waits for the notification.
|
||||||
|
|
||||||
|
The write operation will be executed when there isn't other pending commands.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
opcode: opcode defined in `MediaControlPointOpcode`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response code provided in `MediaControlPointResultCode`
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidOperationError: Server does not have Media Control Point Characteristic.
|
||||||
|
InvalidStateError: Server replies a notification with mismatched opcode.
|
||||||
|
'''
|
||||||
|
if not self.media_control_point:
|
||||||
|
raise core.InvalidOperationError("Peer does not have media control point")
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
await self.media_control_point.write_value(
|
||||||
|
bytes([opcode]),
|
||||||
|
with_response=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
response_opcode,
|
||||||
|
response_code,
|
||||||
|
) = await self.media_control_point_notifications.get()
|
||||||
|
if response_opcode != opcode:
|
||||||
|
raise core.InvalidStateError(
|
||||||
|
f"Expected {opcode} notification, but get {response_opcode}"
|
||||||
|
)
|
||||||
|
return MediaControlPointResultCode(response_code)
|
||||||
|
|
||||||
|
def _on_media_control_point(self, data: bytes) -> None:
|
||||||
|
self.media_control_point_notifications.put_nowait(data)
|
||||||
|
|
||||||
|
def _on_media_state(self, data: bytes) -> None:
|
||||||
|
self.emit('media_state', MediaState(data[0]))
|
||||||
|
|
||||||
|
def _on_track_changed(self, data: bytes) -> None:
|
||||||
|
del data
|
||||||
|
self.emit('track_changed')
|
||||||
|
|
||||||
|
def _on_track_title(self, data: bytes) -> None:
|
||||||
|
self.emit('track_title', data.decode("utf-8"))
|
||||||
|
|
||||||
|
def _on_track_duration(self, data: bytes) -> None:
|
||||||
|
self.emit('track_duration', struct.unpack_from('<i', data)[0])
|
||||||
|
|
||||||
|
def _on_track_position(self, data: bytes) -> None:
|
||||||
|
self.emit('track_position', struct.unpack_from('<i', data)[0])
|
||||||
|
|
||||||
|
|
||||||
|
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
|
||||||
|
SERVICE_CLASS = GenericMediaControlService
|
||||||
210
bumble/profiles/pacs.py
Normal file
210
bumble/profiles/pacs.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
# Copyright 2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for
|
||||||
|
|
||||||
|
"""LE Audio - Published Audio Capabilities Service"""
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
|
||||||
|
from bumble.profiles import le_audio
|
||||||
|
from bumble import gatt
|
||||||
|
from bumble import gatt_client
|
||||||
|
from bumble import hci
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PacRecord:
|
||||||
|
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
|
||||||
|
|
||||||
|
coding_format: hci.CodingFormat
|
||||||
|
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||||
|
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, data: bytes) -> PacRecord:
|
||||||
|
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
|
||||||
|
codec_specific_capabilities_size = data[offset]
|
||||||
|
|
||||||
|
offset += 1
|
||||||
|
codec_specific_capabilities_bytes = data[
|
||||||
|
offset : offset + codec_specific_capabilities_size
|
||||||
|
]
|
||||||
|
offset += codec_specific_capabilities_size
|
||||||
|
metadata_size = data[offset]
|
||||||
|
offset += 1
|
||||||
|
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
|
||||||
|
|
||||||
|
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||||
|
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||||
|
codec_specific_capabilities = codec_specific_capabilities_bytes
|
||||||
|
else:
|
||||||
|
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
|
||||||
|
codec_specific_capabilities_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
return PacRecord(
|
||||||
|
coding_format=coding_format,
|
||||||
|
codec_specific_capabilities=codec_specific_capabilities,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
||||||
|
metadata_bytes = bytes(self.metadata)
|
||||||
|
return (
|
||||||
|
bytes(self.coding_format)
|
||||||
|
+ bytes([len(capabilities_bytes)])
|
||||||
|
+ capabilities_bytes
|
||||||
|
+ bytes([len(metadata_bytes)])
|
||||||
|
+ metadata_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Server
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||||
|
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
|
||||||
|
|
||||||
|
sink_pac: Optional[gatt.Characteristic]
|
||||||
|
sink_audio_locations: Optional[gatt.Characteristic]
|
||||||
|
source_pac: Optional[gatt.Characteristic]
|
||||||
|
source_audio_locations: Optional[gatt.Characteristic]
|
||||||
|
available_audio_contexts: gatt.Characteristic
|
||||||
|
supported_audio_contexts: gatt.Characteristic
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
supported_source_context: ContextType,
|
||||||
|
supported_sink_context: ContextType,
|
||||||
|
available_source_context: ContextType,
|
||||||
|
available_sink_context: ContextType,
|
||||||
|
sink_pac: Sequence[PacRecord] = (),
|
||||||
|
sink_audio_locations: Optional[AudioLocation] = None,
|
||||||
|
source_pac: Sequence[PacRecord] = (),
|
||||||
|
source_audio_locations: Optional[AudioLocation] = None,
|
||||||
|
) -> None:
|
||||||
|
characteristics = []
|
||||||
|
|
||||||
|
self.supported_audio_contexts = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=struct.pack('<HH', supported_sink_context, supported_source_context),
|
||||||
|
)
|
||||||
|
characteristics.append(self.supported_audio_contexts)
|
||||||
|
|
||||||
|
self.available_audio_contexts = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ
|
||||||
|
| gatt.Characteristic.Properties.NOTIFY,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=struct.pack('<HH', available_sink_context, available_source_context),
|
||||||
|
)
|
||||||
|
characteristics.append(self.available_audio_contexts)
|
||||||
|
|
||||||
|
if sink_pac:
|
||||||
|
self.sink_pac = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
|
||||||
|
)
|
||||||
|
characteristics.append(self.sink_pac)
|
||||||
|
|
||||||
|
if sink_audio_locations is not None:
|
||||||
|
self.sink_audio_locations = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=struct.pack('<I', sink_audio_locations),
|
||||||
|
)
|
||||||
|
characteristics.append(self.sink_audio_locations)
|
||||||
|
|
||||||
|
if source_pac:
|
||||||
|
self.source_pac = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
|
||||||
|
)
|
||||||
|
characteristics.append(self.source_pac)
|
||||||
|
|
||||||
|
if source_audio_locations is not None:
|
||||||
|
self.source_audio_locations = gatt.Characteristic(
|
||||||
|
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
|
||||||
|
properties=gatt.Characteristic.Properties.READ,
|
||||||
|
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||||
|
value=struct.pack('<I', source_audio_locations),
|
||||||
|
)
|
||||||
|
characteristics.append(self.source_audio_locations)
|
||||||
|
|
||||||
|
super().__init__(characteristics)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Client
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
||||||
|
SERVICE_CLASS = PublishedAudioCapabilitiesService
|
||||||
|
|
||||||
|
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
source_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||||
|
available_audio_contexts: gatt_client.CharacteristicProxy
|
||||||
|
supported_audio_contexts: gatt_client.CharacteristicProxy
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
|
||||||
|
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||||
|
)[0]
|
||||||
|
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SINK_PAC_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.sink_pac = characteristics[0]
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.source_pac = characteristics[0]
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.sink_audio_locations = characteristics[0]
|
||||||
|
|
||||||
|
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
||||||
|
):
|
||||||
|
self.source_audio_locations = characteristics[0]
|
||||||
89
bumble/profiles/tmap.py
Normal file
89
bumble/profiles/tmap.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Copyright 2021-2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""LE Audio - Telephony and Media Audio Profile"""
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from bumble.gatt import (
|
||||||
|
TemplateService,
|
||||||
|
Characteristic,
|
||||||
|
DelegatedCharacteristicAdapter,
|
||||||
|
InvalidServiceError,
|
||||||
|
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
|
||||||
|
GATT_TMAP_ROLE_CHARACTERISTIC,
|
||||||
|
)
|
||||||
|
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Classes
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class Role(enum.IntFlag):
|
||||||
|
CALL_GATEWAY = 1 << 0
|
||||||
|
CALL_TERMINAL = 1 << 1
|
||||||
|
UNICAST_MEDIA_SENDER = 1 << 2
|
||||||
|
UNICAST_MEDIA_RECEIVER = 1 << 3
|
||||||
|
BROADCAST_MEDIA_SENDER = 1 << 4
|
||||||
|
BROADCAST_MEDIA_RECEIVER = 1 << 5
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class TelephonyAndMediaAudioService(TemplateService):
|
||||||
|
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
|
||||||
|
|
||||||
|
def __init__(self, role: Role):
|
||||||
|
self.role_characteristic = Characteristic(
|
||||||
|
GATT_TMAP_ROLE_CHARACTERISTIC,
|
||||||
|
Characteristic.Properties.READ,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
struct.pack('<H', int(role)),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__([self.role_characteristic])
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
|
||||||
|
SERVICE_CLASS = TelephonyAndMediaAudioService
|
||||||
|
|
||||||
|
role: DelegatedCharacteristicAdapter
|
||||||
|
|
||||||
|
def __init__(self, service_proxy: ServiceProxy):
|
||||||
|
self.service_proxy = service_proxy
|
||||||
|
|
||||||
|
if not (
|
||||||
|
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||||
|
GATT_TMAP_ROLE_CHARACTERISTIC
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise InvalidServiceError('TMAP Role characteristic not found')
|
||||||
|
|
||||||
|
self.role = DelegatedCharacteristicAdapter(
|
||||||
|
characteristics[0],
|
||||||
|
decode=lambda value: Role(
|
||||||
|
struct.unpack_from('<H', value, 0)[0],
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -36,7 +36,9 @@ from .core import (
|
|||||||
BT_RFCOMM_PROTOCOL_ID,
|
BT_RFCOMM_PROTOCOL_ID,
|
||||||
BT_BR_EDR_TRANSPORT,
|
BT_BR_EDR_TRANSPORT,
|
||||||
BT_L2CAP_PROTOCOL_ID,
|
BT_L2CAP_PROTOCOL_ID,
|
||||||
|
InvalidArgumentError,
|
||||||
InvalidStateError,
|
InvalidStateError,
|
||||||
|
InvalidPacketError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -335,7 +337,7 @@ class RFCOMM_Frame:
|
|||||||
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
|
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
|
||||||
if frame.fcs != fcs:
|
if frame.fcs != fcs:
|
||||||
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
|
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
|
||||||
raise ValueError('fcs mismatch')
|
raise InvalidPacketError('fcs mismatch')
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
@@ -713,7 +715,7 @@ class DLC(EventEmitter):
|
|||||||
# Automatically convert strings to bytes using UTF-8
|
# Automatically convert strings to bytes using UTF-8
|
||||||
data = data.encode('utf-8')
|
data = data.encode('utf-8')
|
||||||
else:
|
else:
|
||||||
raise ValueError('write only accept bytes or strings')
|
raise InvalidArgumentError('write only accept bytes or strings')
|
||||||
|
|
||||||
self.tx_buffer += data
|
self.tx_buffer += data
|
||||||
self.drained.clear()
|
self.drained.clear()
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from . import core, l2cap
|
from . import core, l2cap
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .core import InvalidStateError
|
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
|
||||||
from .hci import HCI_Object, name_or_number, key_with_value
|
from .hci import HCI_Object, name_or_number, key_with_value
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -189,7 +189,9 @@ class DataElement:
|
|||||||
self.bytes = None
|
self.bytes = None
|
||||||
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||||
if value_size is None:
|
if value_size is None:
|
||||||
raise ValueError('integer types must have a value size specified')
|
raise InvalidArgumentError(
|
||||||
|
'integer types must have a value size specified'
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def nil() -> DataElement:
|
def nil() -> DataElement:
|
||||||
@@ -265,7 +267,7 @@ class DataElement:
|
|||||||
if len(data) == 8:
|
if len(data) == 8:
|
||||||
return struct.unpack('>Q', data)[0]
|
return struct.unpack('>Q', data)[0]
|
||||||
|
|
||||||
raise ValueError(f'invalid integer length {len(data)}')
|
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signed_integer_from_bytes(data):
|
def signed_integer_from_bytes(data):
|
||||||
@@ -281,7 +283,7 @@ class DataElement:
|
|||||||
if len(data) == 8:
|
if len(data) == 8:
|
||||||
return struct.unpack('>q', data)[0]
|
return struct.unpack('>q', data)[0]
|
||||||
|
|
||||||
raise ValueError(f'invalid integer length {len(data)}')
|
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_from_bytes(data):
|
def list_from_bytes(data):
|
||||||
@@ -354,7 +356,7 @@ class DataElement:
|
|||||||
data = b''
|
data = b''
|
||||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||||
if self.value < 0:
|
if self.value < 0:
|
||||||
raise ValueError('UNSIGNED_INTEGER cannot be negative')
|
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||||
|
|
||||||
if self.value_size == 1:
|
if self.value_size == 1:
|
||||||
data = struct.pack('B', self.value)
|
data = struct.pack('B', self.value)
|
||||||
@@ -365,7 +367,7 @@ class DataElement:
|
|||||||
elif self.value_size == 8:
|
elif self.value_size == 8:
|
||||||
data = struct.pack('>Q', self.value)
|
data = struct.pack('>Q', self.value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid value_size')
|
raise InvalidArgumentError('invalid value_size')
|
||||||
elif self.type == DataElement.SIGNED_INTEGER:
|
elif self.type == DataElement.SIGNED_INTEGER:
|
||||||
if self.value_size == 1:
|
if self.value_size == 1:
|
||||||
data = struct.pack('b', self.value)
|
data = struct.pack('b', self.value)
|
||||||
@@ -376,7 +378,7 @@ class DataElement:
|
|||||||
elif self.value_size == 8:
|
elif self.value_size == 8:
|
||||||
data = struct.pack('>q', self.value)
|
data = struct.pack('>q', self.value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid value_size')
|
raise InvalidArgumentError('invalid value_size')
|
||||||
elif self.type == DataElement.UUID:
|
elif self.type == DataElement.UUID:
|
||||||
data = bytes(reversed(bytes(self.value)))
|
data = bytes(reversed(bytes(self.value)))
|
||||||
elif self.type == DataElement.URL:
|
elif self.type == DataElement.URL:
|
||||||
@@ -392,7 +394,7 @@ class DataElement:
|
|||||||
size_bytes = b''
|
size_bytes = b''
|
||||||
if self.type == DataElement.NIL:
|
if self.type == DataElement.NIL:
|
||||||
if size != 0:
|
if size != 0:
|
||||||
raise ValueError('NIL must be empty')
|
raise InvalidArgumentError('NIL must be empty')
|
||||||
size_index = 0
|
size_index = 0
|
||||||
elif self.type in (
|
elif self.type in (
|
||||||
DataElement.UNSIGNED_INTEGER,
|
DataElement.UNSIGNED_INTEGER,
|
||||||
@@ -410,7 +412,7 @@ class DataElement:
|
|||||||
elif size == 16:
|
elif size == 16:
|
||||||
size_index = 4
|
size_index = 4
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid data size')
|
raise InvalidArgumentError('invalid data size')
|
||||||
elif self.type in (
|
elif self.type in (
|
||||||
DataElement.TEXT_STRING,
|
DataElement.TEXT_STRING,
|
||||||
DataElement.SEQUENCE,
|
DataElement.SEQUENCE,
|
||||||
@@ -427,10 +429,10 @@ class DataElement:
|
|||||||
size_index = 7
|
size_index = 7
|
||||||
size_bytes = struct.pack('>I', size)
|
size_bytes = struct.pack('>I', size)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid data size')
|
raise InvalidArgumentError('invalid data size')
|
||||||
elif self.type == DataElement.BOOLEAN:
|
elif self.type == DataElement.BOOLEAN:
|
||||||
if size != 1:
|
if size != 1:
|
||||||
raise ValueError('boolean must be 1 byte')
|
raise InvalidArgumentError('boolean must be 1 byte')
|
||||||
size_index = 0
|
size_index = 0
|
||||||
|
|
||||||
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from .core import (
|
|||||||
BT_CENTRAL_ROLE,
|
BT_CENTRAL_ROLE,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
AdvertisingData,
|
AdvertisingData,
|
||||||
|
InvalidArgumentError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
name_or_number,
|
name_or_number,
|
||||||
)
|
)
|
||||||
@@ -766,8 +767,11 @@ class Session:
|
|||||||
self.oob_data_flag = 0 if pairing_config.oob is None else 1
|
self.oob_data_flag = 0 if pairing_config.oob is None else 1
|
||||||
|
|
||||||
# Set up addresses
|
# Set up addresses
|
||||||
self_address = connection.self_address
|
self_address = connection.self_resolvable_address or connection.self_address
|
||||||
peer_address = connection.peer_resolvable_address or connection.peer_address
|
peer_address = connection.peer_resolvable_address or connection.peer_address
|
||||||
|
logger.debug(
|
||||||
|
f"pairing with self_address={self_address}, peer_address={peer_address}"
|
||||||
|
)
|
||||||
if self.is_initiator:
|
if self.is_initiator:
|
||||||
self.ia = bytes(self_address)
|
self.ia = bytes(self_address)
|
||||||
self.iat = 1 if self_address.is_random else 0
|
self.iat = 1 if self_address.is_random else 0
|
||||||
@@ -784,7 +788,7 @@ class Session:
|
|||||||
self.peer_oob_data = pairing_config.oob.peer_data
|
self.peer_oob_data = pairing_config.oob.peer_data
|
||||||
if pairing_config.sc:
|
if pairing_config.sc:
|
||||||
if pairing_config.oob.our_context is None:
|
if pairing_config.oob.our_context is None:
|
||||||
raise ValueError(
|
raise InvalidArgumentError(
|
||||||
"oob pairing config requires a context when sc is True"
|
"oob pairing config requires a context when sc is True"
|
||||||
)
|
)
|
||||||
self.r = pairing_config.oob.our_context.r
|
self.r = pairing_config.oob.our_context.r
|
||||||
@@ -793,7 +797,7 @@ class Session:
|
|||||||
self.tk = pairing_config.oob.legacy_context.tk
|
self.tk = pairing_config.oob.legacy_context.tk
|
||||||
else:
|
else:
|
||||||
if pairing_config.oob.legacy_context is None:
|
if pairing_config.oob.legacy_context is None:
|
||||||
raise ValueError(
|
raise InvalidArgumentError(
|
||||||
"oob pairing config requires a legacy context when sc is False"
|
"oob pairing config requires a legacy context when sc is False"
|
||||||
)
|
)
|
||||||
self.r = bytes(16)
|
self.r = bytes(16)
|
||||||
@@ -1074,11 +1078,19 @@ class Session:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def send_identity_address_command(self) -> None:
|
def send_identity_address_command(self) -> None:
|
||||||
identity_address = {
|
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
|
||||||
None: self.connection.self_address,
|
identity_address = self.manager.device.public_address
|
||||||
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
|
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
|
||||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
|
identity_address = self.manager.device.static_address
|
||||||
}[self.pairing_config.identity_address_type]
|
else:
|
||||||
|
# No identity address type set. If the controller has a public address, it
|
||||||
|
# will be more responsible to be the identity address.
|
||||||
|
if self.manager.device.public_address != Address.ANY:
|
||||||
|
logger.debug("No identity address type set, using PUBLIC")
|
||||||
|
identity_address = self.manager.device.public_address
|
||||||
|
else:
|
||||||
|
logger.debug("No identity address type set, using RANDOM")
|
||||||
|
identity_address = self.manager.device.static_address
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Identity_Address_Information_Command(
|
SMP_Identity_Address_Information_Command(
|
||||||
addr_type=identity_address.address_type,
|
addr_type=identity_address.address_type,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import datetime
|
|||||||
from typing import BinaryIO, Generator
|
from typing import BinaryIO, Generator
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||||
|
|
||||||
|
|
||||||
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if ':' not in spec:
|
if ':' not in spec:
|
||||||
raise ValueError('snooper type prefix missing')
|
raise core.InvalidArgumentError('snooper type prefix missing')
|
||||||
|
|
||||||
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
||||||
|
|
||||||
if snooper_type == 'btsnoop':
|
if snooper_type == 'btsnoop':
|
||||||
if ':' not in snooper_args:
|
if ':' not in snooper_args:
|
||||||
raise ValueError('I/O type for btsnoop snooper type missing')
|
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
|
||||||
|
|
||||||
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
||||||
if io_type == 'file':
|
if io_type == 'file':
|
||||||
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
|||||||
_SNOOPER_INSTANCE_COUNT -= 1
|
_SNOOPER_INSTANCE_COUNT -= 1
|
||||||
return
|
return
|
||||||
|
|
||||||
raise ValueError(f'I/O type {io_type} not supported')
|
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
|
||||||
|
|
||||||
raise ValueError(f'snooper type {snooper_type} not found')
|
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
|
||||||
from ..snoop import create_snooper
|
from ..snoop import create_snooper
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -180,7 +180,13 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
|||||||
|
|
||||||
return await open_android_netsim_transport(spec)
|
return await open_android_netsim_transport(spec)
|
||||||
|
|
||||||
raise ValueError('unknown transport scheme')
|
if scheme == 'unix':
|
||||||
|
from .unix import open_unix_client_transport
|
||||||
|
|
||||||
|
assert spec
|
||||||
|
return await open_unix_client_transport(spec)
|
||||||
|
|
||||||
|
raise TransportSpecError('unknown transport scheme')
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -20,7 +20,13 @@ import grpc.aio
|
|||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
from .common import (
|
||||||
|
PumpedTransport,
|
||||||
|
PumpedPacketSource,
|
||||||
|
PumpedPacketSink,
|
||||||
|
Transport,
|
||||||
|
TransportSpecError,
|
||||||
|
)
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
||||||
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
|||||||
elif ':' in param:
|
elif ':' in param:
|
||||||
server_host, server_port = param.split(':')
|
server_host, server_port = param.split(':')
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid parameter')
|
raise TransportSpecError('invalid parameter')
|
||||||
|
|
||||||
# Connect to the gRPC server
|
# Connect to the gRPC server
|
||||||
server_address = f'{server_host}:{server_port}'
|
server_address = f'{server_host}:{server_port}'
|
||||||
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
|||||||
service = VhciForwardingServiceStub(channel)
|
service = VhciForwardingServiceStub(channel)
|
||||||
hci_device = HciDevice(service.attachVhci())
|
hci_device = HciDevice(service.attachVhci())
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid mode')
|
raise TransportSpecError('invalid mode')
|
||||||
|
|
||||||
# Create the transport object
|
# Create the transport object
|
||||||
class EmulatorTransport(PumpedTransport):
|
class EmulatorTransport(PumpedTransport):
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ from .common import (
|
|||||||
PumpedPacketSource,
|
PumpedPacketSource,
|
||||||
PumpedPacketSink,
|
PumpedPacketSink,
|
||||||
Transport,
|
Transport,
|
||||||
|
TransportSpecError,
|
||||||
|
TransportInitError,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
|
|||||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||||
) -> Transport:
|
) -> Transport:
|
||||||
if not server_port:
|
if not server_port:
|
||||||
raise ValueError('invalid port')
|
raise TransportSpecError('invalid port')
|
||||||
if server_host == '_' or not server_host:
|
if server_host == '_' or not server_host:
|
||||||
server_host = 'localhost'
|
server_host = 'localhost'
|
||||||
|
|
||||||
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
|
|||||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||||
server_port = find_grpc_port(instance_number)
|
server_port = find_grpc_port(instance_number)
|
||||||
if not server_port:
|
if not server_port:
|
||||||
raise RuntimeError('gRPC server port not found')
|
raise TransportInitError('gRPC server port not found')
|
||||||
|
|
||||||
# Connect to the gRPC server
|
# Connect to the gRPC server
|
||||||
server_address = f'{server_host}:{server_port}'
|
server_address = f'{server_host}:{server_port}'
|
||||||
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
|
|||||||
|
|
||||||
if response_type == 'error':
|
if response_type == 'error':
|
||||||
logger.warning(f'received error: {response.error}')
|
logger.warning(f'received error: {response.error}')
|
||||||
raise RuntimeError(response.error)
|
raise TransportInitError(response.error)
|
||||||
|
|
||||||
if response_type == 'hci_packet':
|
if response_type == 'hci_packet':
|
||||||
return (
|
return (
|
||||||
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
|
|||||||
+ response.hci_packet.packet
|
+ response.hci_packet.packet
|
||||||
)
|
)
|
||||||
|
|
||||||
raise ValueError('unsupported response type')
|
raise TransportSpecError('unsupported response type')
|
||||||
|
|
||||||
async def write(self, packet):
|
async def write(self, packet):
|
||||||
await self.hci_device.write(
|
await self.hci_device.write(
|
||||||
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
|||||||
options: Dict[str, str] = {}
|
options: Dict[str, str] = {}
|
||||||
for param in params[params_offset:]:
|
for param in params[params_offset:]:
|
||||||
if '=' not in param:
|
if '=' not in param:
|
||||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
raise TransportSpecError('invalid parameter, expected <name>=<value>')
|
||||||
option_name, option_value = param.split('=')
|
option_name, option_value = param.split('=')
|
||||||
options[option_name] = option_value
|
options[option_name] = option_value
|
||||||
|
|
||||||
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
|||||||
)
|
)
|
||||||
if mode == 'controller':
|
if mode == 'controller':
|
||||||
if host is None:
|
if host is None:
|
||||||
raise ValueError('<host>:<port> missing')
|
raise TransportSpecError('<host>:<port> missing')
|
||||||
return await open_android_netsim_controller_transport(host, port, options)
|
return await open_android_netsim_controller_transport(host, port, options)
|
||||||
|
|
||||||
raise ValueError('invalid mode option')
|
raise TransportSpecError('invalid mode option')
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import logging
|
|||||||
import io
|
import io
|
||||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||||
|
|
||||||
|
from bumble import core
|
||||||
from bumble import hci
|
from bumble import hci
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.snoop import Snooper
|
from bumble.snoop import Snooper
|
||||||
@@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Errors
|
# Errors
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class TransportLostError(Exception):
|
class TransportLostError(core.BaseBumbleError, RuntimeError):
|
||||||
"""
|
"""The Transport has been lost/disconnected."""
|
||||||
The Transport has been lost/disconnected.
|
|
||||||
"""
|
|
||||||
|
class TransportInitError(core.BaseBumbleError, RuntimeError):
|
||||||
|
"""Error raised when the transport cannot be initialized."""
|
||||||
|
|
||||||
|
|
||||||
|
class TransportSpecError(core.BaseBumbleError, ValueError):
|
||||||
|
"""Error raised when the transport spec is invalid."""
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -132,7 +139,9 @@ class PacketParser:
|
|||||||
packet_type
|
packet_type
|
||||||
) or self.extended_packet_info.get(packet_type)
|
) or self.extended_packet_info.get(packet_type)
|
||||||
if self.packet_info is None:
|
if self.packet_info is None:
|
||||||
raise ValueError(f'invalid packet type {packet_type}')
|
raise core.InvalidPacketError(
|
||||||
|
f'invalid packet type {packet_type}'
|
||||||
|
)
|
||||||
self.state = PacketParser.NEED_LENGTH
|
self.state = PacketParser.NEED_LENGTH
|
||||||
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
||||||
elif self.state == PacketParser.NEED_LENGTH:
|
elif self.state == PacketParser.NEED_LENGTH:
|
||||||
@@ -178,19 +187,19 @@ class PacketReader:
|
|||||||
# Get the packet info based on its type
|
# Get the packet info based on its type
|
||||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||||
if packet_info is None:
|
if packet_info is None:
|
||||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||||
|
|
||||||
# Read the header (that includes the length)
|
# Read the header (that includes the length)
|
||||||
header_size = packet_info[0] + packet_info[1]
|
header_size = packet_info[0] + packet_info[1]
|
||||||
header = self.source.read(header_size)
|
header = self.source.read(header_size)
|
||||||
if len(header) != header_size:
|
if len(header) != header_size:
|
||||||
raise ValueError('packet too short')
|
raise core.InvalidPacketError('packet too short')
|
||||||
|
|
||||||
# Read the body
|
# Read the body
|
||||||
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
|
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
|
||||||
body = self.source.read(body_length)
|
body = self.source.read(body_length)
|
||||||
if len(body) != body_length:
|
if len(body) != body_length:
|
||||||
raise ValueError('packet too short')
|
raise core.InvalidPacketError('packet too short')
|
||||||
|
|
||||||
return packet_type + header + body
|
return packet_type + header + body
|
||||||
|
|
||||||
@@ -211,7 +220,7 @@ class AsyncPacketReader:
|
|||||||
# Get the packet info based on its type
|
# Get the packet info based on its type
|
||||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||||
if packet_info is None:
|
if packet_info is None:
|
||||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||||
|
|
||||||
# Read the header (that includes the length)
|
# Read the header (that includes the length)
|
||||||
header_size = packet_info[0] + packet_info[1]
|
header_size = packet_info[0] + packet_info[1]
|
||||||
@@ -239,26 +248,28 @@ class AsyncPipeSink:
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class ParserSource:
|
class BaseSource:
|
||||||
"""
|
"""
|
||||||
Base class designed to be subclassed by transport-specific source classes
|
Base class designed to be subclassed by transport-specific source classes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
terminated: asyncio.Future[None]
|
terminated: asyncio.Future[None]
|
||||||
parser: PacketParser
|
sink: Optional[TransportSink]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.parser = PacketParser()
|
|
||||||
self.terminated = asyncio.get_running_loop().create_future()
|
self.terminated = asyncio.get_running_loop().create_future()
|
||||||
|
self.sink = None
|
||||||
|
|
||||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||||
self.parser.set_packet_sink(sink)
|
self.sink = sink
|
||||||
|
|
||||||
def on_transport_lost(self) -> None:
|
def on_transport_lost(self) -> None:
|
||||||
self.terminated.set_result(None)
|
if not self.terminated.done():
|
||||||
if self.parser.sink:
|
self.terminated.set_result(None)
|
||||||
if hasattr(self.parser.sink, 'on_transport_lost'):
|
|
||||||
self.parser.sink.on_transport_lost()
|
if self.sink:
|
||||||
|
if hasattr(self.sink, 'on_transport_lost'):
|
||||||
|
self.sink.on_transport_lost()
|
||||||
|
|
||||||
async def wait_for_termination(self) -> None:
|
async def wait_for_termination(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -271,6 +282,23 @@ class ParserSource:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class ParserSource(BaseSource):
|
||||||
|
"""
|
||||||
|
Base class for sources that use an HCI parser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parser: PacketParser
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.parser = PacketParser()
|
||||||
|
|
||||||
|
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||||
|
super().set_packet_sink(sink)
|
||||||
|
self.parser.set_packet_sink(sink)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class StreamPacketSource(asyncio.Protocol, ParserSource):
|
class StreamPacketSource(asyncio.Protocol, ParserSource):
|
||||||
def data_received(self, data: bytes) -> None:
|
def data_received(self, data: bytes) -> None:
|
||||||
@@ -420,7 +448,7 @@ class SnoopingTransport(Transport):
|
|||||||
return SnoopingTransport(
|
return SnoopingTransport(
|
||||||
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
||||||
)
|
)
|
||||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
raise core.UnreachableError() # Satisfy the type checker
|
||||||
|
|
||||||
class Source:
|
class Source:
|
||||||
sink: TransportSink
|
sink: TransportSink
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from usb.core import USBError
|
|||||||
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
|
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
|
||||||
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
|
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
|
||||||
|
|
||||||
from .common import Transport, ParserSource
|
from .common import Transport, ParserSource, TransportInitError
|
||||||
from .. import hci
|
from .. import hci
|
||||||
from ..colors import color
|
from ..colors import color
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
|||||||
device = None
|
device = None
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
raise ValueError('device not found')
|
raise TransportInitError('device not found')
|
||||||
logger.debug(f'USB Device: {device}')
|
logger.debug(f'USB Device: {device}')
|
||||||
|
|
||||||
# Power Cycle the device
|
# Power Cycle the device
|
||||||
|
|||||||
56
bumble/transport/unix.py
Normal file
56
bumble/transport/unix.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2021-2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .common import Transport, StreamPacketSource, StreamPacketSink
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def open_unix_client_transport(spec: str) -> Transport:
|
||||||
|
'''Open a UNIX socket client transport.
|
||||||
|
|
||||||
|
The parameter is the path of unix socket. For abstract socket, the first character
|
||||||
|
needs to be '@'.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
* /tmp/hci.socket
|
||||||
|
* @hci_socket
|
||||||
|
'''
|
||||||
|
|
||||||
|
class UnixPacketSource(StreamPacketSource):
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
logger.debug(f'connection lost: {exc}')
|
||||||
|
self.on_transport_lost()
|
||||||
|
|
||||||
|
# For abstract socket, the first character should be null character.
|
||||||
|
if spec.startswith('@'):
|
||||||
|
spec = '\0' + spec[1:]
|
||||||
|
|
||||||
|
(
|
||||||
|
unix_transport,
|
||||||
|
packet_source,
|
||||||
|
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
|
||||||
|
packet_sink = StreamPacketSink(unix_transport)
|
||||||
|
|
||||||
|
return Transport(packet_source, packet_sink)
|
||||||
@@ -15,19 +15,18 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import collections
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
import usb1
|
import usb1
|
||||||
|
|
||||||
from bumble.transport.common import Transport, ParserSource
|
from bumble.transport.common import Transport, BaseSource, TransportInitError
|
||||||
from bumble import hci
|
from bumble import hci
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.utils import AsyncRunner
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -115,13 +114,17 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.acl_out = acl_out
|
self.acl_out = acl_out
|
||||||
self.acl_out_transfer = device.getTransfer()
|
self.acl_out_transfer = device.getTransfer()
|
||||||
self.packets = collections.deque() # Queue of packets waiting to be sent
|
self.acl_out_transfer_ready = asyncio.Semaphore(1)
|
||||||
|
self.packets: asyncio.Queue[bytes] = (
|
||||||
|
asyncio.Queue()
|
||||||
|
) # Queue of packets waiting to be sent
|
||||||
self.loop = asyncio.get_running_loop()
|
self.loop = asyncio.get_running_loop()
|
||||||
|
self.queue_task = None
|
||||||
self.cancel_done = self.loop.create_future()
|
self.cancel_done = self.loop.create_future()
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
pass
|
self.queue_task = asyncio.create_task(self.process_queue())
|
||||||
|
|
||||||
def on_packet(self, packet):
|
def on_packet(self, packet):
|
||||||
# Ignore packets if we're closed
|
# Ignore packets if we're closed
|
||||||
@@ -133,62 +136,64 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Queue the packet
|
# Queue the packet
|
||||||
self.packets.append(packet)
|
self.packets.put_nowait(packet)
|
||||||
if len(self.packets) == 1:
|
|
||||||
# The queue was previously empty, re-prime the pump
|
|
||||||
self.process_queue()
|
|
||||||
|
|
||||||
def transfer_callback(self, transfer):
|
def transfer_callback(self, transfer):
|
||||||
|
self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
|
||||||
status = transfer.getStatus()
|
status = transfer.getStatus()
|
||||||
|
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
if status == usb1.TRANSFER_COMPLETED:
|
if status == usb1.TRANSFER_CANCELLED:
|
||||||
self.loop.call_soon_threadsafe(self.on_packet_sent)
|
|
||||||
elif status == usb1.TRANSFER_CANCELLED:
|
|
||||||
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
|
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
|
||||||
else:
|
return
|
||||||
|
|
||||||
|
if status != usb1.TRANSFER_COMPLETED:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_packet_sent(self):
|
async def process_queue(self):
|
||||||
if self.packets:
|
while True:
|
||||||
self.packets.popleft()
|
# Wait for a packet to transfer.
|
||||||
self.process_queue()
|
packet = await self.packets.get()
|
||||||
|
|
||||||
def process_queue(self):
|
# Wait until we can start a transfer.
|
||||||
if len(self.packets) == 0:
|
await self.acl_out_transfer_ready.acquire()
|
||||||
return # Nothing to do
|
|
||||||
|
|
||||||
packet = self.packets[0]
|
# Transfer the packet.
|
||||||
packet_type = packet[0]
|
packet_type = packet[0]
|
||||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||||
self.acl_out_transfer.setBulk(
|
self.acl_out_transfer.setBulk(
|
||||||
self.acl_out, packet[1:], callback=self.transfer_callback
|
self.acl_out, packet[1:], callback=self.transfer_callback
|
||||||
)
|
)
|
||||||
self.acl_out_transfer.submit()
|
self.acl_out_transfer.submit()
|
||||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||||
self.acl_out_transfer.setControl(
|
self.acl_out_transfer.setControl(
|
||||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
packet[1:],
|
packet[1:],
|
||||||
callback=self.transfer_callback,
|
callback=self.transfer_callback,
|
||||||
)
|
)
|
||||||
self.acl_out_transfer.submit()
|
self.acl_out_transfer.submit()
|
||||||
else:
|
else:
|
||||||
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
|
logger.warning(
|
||||||
|
color(f'unsupported packet type {packet_type}', 'red')
|
||||||
|
)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
if self.queue_task:
|
||||||
|
self.queue_task.cancel()
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
# Empty the packet queue so that we don't send any more data
|
# Empty the packet queue so that we don't send any more data
|
||||||
self.packets.clear()
|
while not self.packets.empty():
|
||||||
|
self.packets.get_nowait()
|
||||||
|
|
||||||
# If we have a transfer in flight, cancel it
|
# If we have a transfer in flight, cancel it
|
||||||
if self.acl_out_transfer.isSubmitted():
|
if self.acl_out_transfer.isSubmitted():
|
||||||
@@ -203,7 +208,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
except usb1.USBError:
|
except usb1.USBError:
|
||||||
logger.debug('OUT transfer likely already completed')
|
logger.debug('OUT transfer likely already completed')
|
||||||
|
|
||||||
class UsbPacketSource(asyncio.Protocol, ParserSource):
|
class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||||
def __init__(self, device, metadata, acl_in, events_in):
|
def __init__(self, device, metadata, acl_in, events_in):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -280,7 +285,13 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
packet = await self.queue.get()
|
packet = await self.queue.get()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return
|
return
|
||||||
self.parser.feed_data(packet)
|
if self.sink:
|
||||||
|
try:
|
||||||
|
self.sink.on_packet(packet)
|
||||||
|
except Exception as error:
|
||||||
|
logger.exception(
|
||||||
|
color(f'!!! Exception in sink.on_packet: {error}', 'red')
|
||||||
|
)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.closed = True
|
self.closed = True
|
||||||
@@ -442,7 +453,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
|
|
||||||
if found is None:
|
if found is None:
|
||||||
context.close()
|
context.close()
|
||||||
raise ValueError('device not found')
|
raise TransportInitError('device not found')
|
||||||
|
|
||||||
logger.debug(f'USB Device: {found}')
|
logger.debug(f'USB Device: {found}')
|
||||||
|
|
||||||
@@ -507,7 +518,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
|||||||
|
|
||||||
endpoints = find_endpoints(found)
|
endpoints = find_endpoints(found)
|
||||||
if endpoints is None:
|
if endpoints is None:
|
||||||
raise ValueError('no compatible interface found for device')
|
raise TransportInitError('no compatible interface found for device')
|
||||||
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
|
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'selected endpoints: configuration={configuration}, '
|
f'selected endpoints: configuration={configuration}, '
|
||||||
|
|||||||
BIN
docs/images/favicon.ico
Normal file
BIN
docs/images/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
7
examples/device_with_rpa.json
Normal file
7
examples/device_with_rpa.json
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"name": "Bumble",
|
||||||
|
"address": "F0:F1:F2:F3:F4:F5",
|
||||||
|
"keystore": "JsonKeyStore",
|
||||||
|
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
|
||||||
|
"le_privacy_enabled": true
|
||||||
|
}
|
||||||
@@ -3,5 +3,6 @@
|
|||||||
"keystore": "JsonKeyStore",
|
"keystore": "JsonKeyStore",
|
||||||
"address": "F0:F1:F2:F3:F4:FA",
|
"address": "F0:F1:F2:F3:F4:FA",
|
||||||
"class_of_device": 2376708,
|
"class_of_device": 2376708,
|
||||||
|
"cis_enabled": true,
|
||||||
"advertising_interval": 100
|
"advertising_interval": 100
|
||||||
}
|
}
|
||||||
|
|||||||
83
examples/mcp_server.html
Normal file
83
examples/mcp_server.html
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
<html data-bs-theme="dark">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
|
||||||
|
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<nav class="navbar navbar-dark bg-primary">
|
||||||
|
<div class="container">
|
||||||
|
<span class="navbar-brand mb-0 h1">Bumble LEA Media Control Client</span>
|
||||||
|
</div>
|
||||||
|
</nav>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
<div class="container">
|
||||||
|
|
||||||
|
<label class="form-label">Server Port</label>
|
||||||
|
<div class="input-group mb-3">
|
||||||
|
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
|
||||||
|
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x01)">Play</button>
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x02)">Pause</button>
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x03)">Fast Rewind</button>
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x04)">Fast Forward</button>
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x05)">Stop</button>
|
||||||
|
|
||||||
|
</br></br>
|
||||||
|
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x30)">Previous Track</button>
|
||||||
|
<button class="btn btn-primary" onclick="send_opcode(0x31)">Next Track</button>
|
||||||
|
|
||||||
|
<hr>
|
||||||
|
|
||||||
|
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
|
||||||
|
<h3>Log</h3>
|
||||||
|
<code id="log" style="white-space: pre-line;"></code>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<script>
|
||||||
|
let portInput = document.getElementById("port")
|
||||||
|
let log = document.getElementById("log")
|
||||||
|
let socket
|
||||||
|
|
||||||
|
function connect() {
|
||||||
|
socket = new WebSocket(`ws://localhost:${portInput.value}`);
|
||||||
|
socket.onopen = _ => {
|
||||||
|
log.textContent += 'OPEN\n'
|
||||||
|
}
|
||||||
|
socket.onclose = _ => {
|
||||||
|
log.textContent += 'CLOSED\n'
|
||||||
|
}
|
||||||
|
socket.onerror = (error) => {
|
||||||
|
log.textContent += 'ERROR\n'
|
||||||
|
console.log(`ERROR: ${error}`)
|
||||||
|
}
|
||||||
|
socket.onmessage = (event) => {
|
||||||
|
log.textContent += `<-- ${event.data}\n`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function send(message) {
|
||||||
|
if (socket && socket.readyState == WebSocket.OPEN) {
|
||||||
|
let jsonMessage = JSON.stringify(message)
|
||||||
|
log.textContent += `--> ${jsonMessage}\n`
|
||||||
|
socket.send(jsonMessage)
|
||||||
|
} else {
|
||||||
|
log.textContent += 'NOT CONNECTED\n'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function send_opcode(opcode) {
|
||||||
|
send({ 'opcode': opcode })
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
@@ -21,7 +21,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import websockets
|
import websockets
|
||||||
from bumble.colors import color
|
import struct
|
||||||
|
|
||||||
from bumble.device import Device
|
from bumble.device import Device
|
||||||
from bumble.transport import open_transport_or_link
|
from bumble.transport import open_transport_or_link
|
||||||
@@ -30,9 +30,7 @@ from bumble.core import (
|
|||||||
BT_L2CAP_PROTOCOL_ID,
|
BT_L2CAP_PROTOCOL_ID,
|
||||||
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
||||||
BT_HIDP_PROTOCOL_ID,
|
BT_HIDP_PROTOCOL_ID,
|
||||||
UUID,
|
|
||||||
)
|
)
|
||||||
from bumble.hci import Address
|
|
||||||
from bumble.hid import (
|
from bumble.hid import (
|
||||||
Device as HID_Device,
|
Device as HID_Device,
|
||||||
HID_CONTROL_PSM,
|
HID_CONTROL_PSM,
|
||||||
@@ -40,20 +38,17 @@ from bumble.hid import (
|
|||||||
Message,
|
Message,
|
||||||
)
|
)
|
||||||
from bumble.sdp import (
|
from bumble.sdp import (
|
||||||
Client as SDP_Client,
|
|
||||||
DataElement,
|
DataElement,
|
||||||
ServiceAttribute,
|
ServiceAttribute,
|
||||||
SDP_PUBLIC_BROWSE_ROOT,
|
SDP_PUBLIC_BROWSE_ROOT,
|
||||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_ALL_ATTRIBUTES_RANGE,
|
|
||||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
)
|
)
|
||||||
from bumble.utils import AsyncRunner
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# SDP attributes for Bluetooth HID devices
|
# SDP attributes for Bluetooth HID devices
|
||||||
@@ -430,7 +425,7 @@ deviceData = DeviceData()
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def keyboard_device(hid_device):
|
async def keyboard_device(hid_device: HID_Device):
|
||||||
|
|
||||||
# Start a Websocket server to receive events from a web page
|
# Start a Websocket server to receive events from a web page
|
||||||
async def serve(websocket, _path):
|
async def serve(websocket, _path):
|
||||||
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
|
|||||||
# limiting x and y values within logical max and min range
|
# limiting x and y values within logical max and min range
|
||||||
x = max(log_min, min(log_max, x))
|
x = max(log_min, min(log_max, x))
|
||||||
y = max(log_min, min(log_max, y))
|
y = max(log_min, min(log_max, y))
|
||||||
x_cord = x.to_bytes(signed=True)
|
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
|
||||||
y_cord = y.to_bytes(signed=True)
|
">bb", x, y
|
||||||
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
|
)
|
||||||
hid_device.send_data(deviceData.mouseData)
|
hid_device.send_data(deviceData.mouseData)
|
||||||
except websockets.exceptions.ConnectionClosedOK:
|
except websockets.exceptions.ConnectionClosedOK:
|
||||||
pass
|
pass
|
||||||
@@ -515,7 +510,9 @@ async def main() -> None:
|
|||||||
def on_hid_data_cb(pdu: bytes):
|
def on_hid_data_cb(pdu: bytes):
|
||||||
print(f'Received Data, PDU: {pdu.hex()}')
|
print(f'Received Data, PDU: {pdu.hex()}')
|
||||||
|
|
||||||
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
|
def on_get_report_cb(
|
||||||
|
report_id: int, report_type: int, buffer_size: int
|
||||||
|
) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
retValue = hid_device.GetSetStatus()
|
||||||
print(
|
print(
|
||||||
"GET_REPORT report_id: "
|
"GET_REPORT report_id: "
|
||||||
@@ -555,8 +552,7 @@ async def main() -> None:
|
|||||||
|
|
||||||
def on_set_report_cb(
|
def on_set_report_cb(
|
||||||
report_id: int, report_type: int, report_size: int, data: bytes
|
report_id: int, report_type: int, report_size: int, data: bytes
|
||||||
):
|
) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
|
||||||
print(
|
print(
|
||||||
"SET_REPORT report_id: "
|
"SET_REPORT report_id: "
|
||||||
+ str(report_id)
|
+ str(report_id)
|
||||||
@@ -568,33 +564,33 @@ async def main() -> None:
|
|||||||
+ str(data)
|
+ str(data)
|
||||||
)
|
)
|
||||||
if report_type == Message.ReportType.FEATURE_REPORT:
|
if report_type == Message.ReportType.FEATURE_REPORT:
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_type == Message.ReportType.INPUT_REPORT:
|
elif report_type == Message.ReportType.INPUT_REPORT:
|
||||||
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||||
elif report_id == 3:
|
elif report_id == 3:
|
||||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||||
else:
|
else:
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status = HID_Device.GetSetReturn.SUCCESS
|
||||||
else:
|
else:
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status = HID_Device.GetSetReturn.SUCCESS
|
||||||
|
|
||||||
return retValue
|
return HID_Device.GetSetStatus(status=status)
|
||||||
|
|
||||||
def on_get_protocol_cb():
|
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
return HID_Device.GetSetStatus(
|
||||||
retValue.data = protocol_mode.to_bytes()
|
data=bytes([protocol_mode]),
|
||||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
status=hid_device.GetSetReturn.SUCCESS,
|
||||||
return retValue
|
)
|
||||||
|
|
||||||
def on_set_protocol_cb(protocol: int):
|
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
|
||||||
retValue = hid_device.GetSetStatus()
|
|
||||||
# We do not support SET_PROTOCOL.
|
# We do not support SET_PROTOCOL.
|
||||||
print(f"SET_PROTOCOL report_id: {protocol}")
|
print(f"SET_PROTOCOL report_id: {protocol}")
|
||||||
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
return HID_Device.GetSetStatus(
|
||||||
return retValue
|
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
def on_virtual_cable_unplug_cb():
|
def on_virtual_cable_unplug_cb():
|
||||||
print('Received Virtual Cable Unplug')
|
print('Received Virtual Cable Unplug')
|
||||||
|
|||||||
194
examples/run_mcp_client.py
Normal file
194
examples/run_mcp_client.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# Copyright 2021-2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import websockets
|
||||||
|
import json
|
||||||
|
|
||||||
|
from bumble.core import AdvertisingData
|
||||||
|
from bumble.device import (
|
||||||
|
Device,
|
||||||
|
AdvertisingParameters,
|
||||||
|
AdvertisingEventProperties,
|
||||||
|
Connection,
|
||||||
|
Peer,
|
||||||
|
)
|
||||||
|
from bumble.hci import (
|
||||||
|
CodecID,
|
||||||
|
CodingFormat,
|
||||||
|
OwnAddressType,
|
||||||
|
)
|
||||||
|
from bumble.profiles.ascs import AudioStreamControlService
|
||||||
|
from bumble.profiles.bap import (
|
||||||
|
CodecSpecificCapabilities,
|
||||||
|
ContextType,
|
||||||
|
AudioLocation,
|
||||||
|
SupportedSamplingFrequency,
|
||||||
|
SupportedFrameDuration,
|
||||||
|
UnicastServerAdvertisingData,
|
||||||
|
)
|
||||||
|
from bumble.profiles.mcp import (
|
||||||
|
MediaControlServiceProxy,
|
||||||
|
GenericMediaControlServiceProxy,
|
||||||
|
MediaState,
|
||||||
|
MediaControlPointOpcode,
|
||||||
|
)
|
||||||
|
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
|
||||||
|
from bumble.transport import open_transport_or_link
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def main() -> None:
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
print('Usage: run_mcp_client.py <config-file>' '<transport-spec-for-device>')
|
||||||
|
return
|
||||||
|
|
||||||
|
print('<<< connecting to HCI...')
|
||||||
|
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
|
||||||
|
print('<<< connected')
|
||||||
|
|
||||||
|
device = Device.from_config_file_with_hci(
|
||||||
|
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||||
|
)
|
||||||
|
|
||||||
|
await device.power_on()
|
||||||
|
|
||||||
|
# Add "placeholder" services to enable Android LEA features.
|
||||||
|
device.add_service(
|
||||||
|
PublishedAudioCapabilitiesService(
|
||||||
|
supported_source_context=ContextType.PROHIBITED,
|
||||||
|
available_source_context=ContextType.PROHIBITED,
|
||||||
|
supported_sink_context=ContextType.MEDIA,
|
||||||
|
available_sink_context=ContextType.MEDIA,
|
||||||
|
sink_audio_locations=(
|
||||||
|
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
|
||||||
|
),
|
||||||
|
sink_pac=[
|
||||||
|
PacRecord(
|
||||||
|
coding_format=CodingFormat(CodecID.LC3),
|
||||||
|
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||||
|
supported_sampling_frequencies=(
|
||||||
|
SupportedSamplingFrequency.FREQ_16000
|
||||||
|
| SupportedSamplingFrequency.FREQ_32000
|
||||||
|
| SupportedSamplingFrequency.FREQ_48000
|
||||||
|
),
|
||||||
|
supported_frame_durations=(
|
||||||
|
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||||
|
),
|
||||||
|
supported_audio_channel_count=[1, 2],
|
||||||
|
min_octets_per_codec_frame=0,
|
||||||
|
max_octets_per_codec_frame=320,
|
||||||
|
supported_max_codec_frames_per_sdu=2,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device.add_service(AudioStreamControlService(device, sink_ase_id=[1]))
|
||||||
|
|
||||||
|
ws: Optional[websockets.WebSocketServerProtocol] = None
|
||||||
|
mcp: Optional[MediaControlServiceProxy] = None
|
||||||
|
|
||||||
|
advertising_data = bytes(
|
||||||
|
AdvertisingData(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||||
|
bytes('Bumble LE Audio', 'utf-8'),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
AdvertisingData.FLAGS,
|
||||||
|
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||||
|
bytes(PublishedAudioCapabilitiesService.UUID),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
) + bytes(UnicastServerAdvertisingData())
|
||||||
|
|
||||||
|
await device.create_advertising_set(
|
||||||
|
advertising_parameters=AdvertisingParameters(
|
||||||
|
advertising_event_properties=AdvertisingEventProperties(),
|
||||||
|
own_address_type=OwnAddressType.RANDOM,
|
||||||
|
primary_advertising_interval_max=100,
|
||||||
|
primary_advertising_interval_min=100,
|
||||||
|
),
|
||||||
|
advertising_data=advertising_data,
|
||||||
|
auto_restart=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_media_state(media_state: MediaState) -> None:
|
||||||
|
if ws:
|
||||||
|
asyncio.create_task(
|
||||||
|
ws.send(json.dumps({'media_state': media_state.name}))
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_track_title(title: str) -> None:
|
||||||
|
if ws:
|
||||||
|
asyncio.create_task(ws.send(json.dumps({'title': title})))
|
||||||
|
|
||||||
|
def on_track_duration(duration: int) -> None:
|
||||||
|
if ws:
|
||||||
|
asyncio.create_task(ws.send(json.dumps({'duration': duration})))
|
||||||
|
|
||||||
|
def on_track_position(position: int) -> None:
|
||||||
|
if ws:
|
||||||
|
asyncio.create_task(ws.send(json.dumps({'position': position})))
|
||||||
|
|
||||||
|
def on_connection(connection: Connection) -> None:
|
||||||
|
async def on_connection_async():
|
||||||
|
async with Peer(connection) as peer:
|
||||||
|
nonlocal mcp
|
||||||
|
mcp = peer.create_service_proxy(MediaControlServiceProxy)
|
||||||
|
if not mcp:
|
||||||
|
mcp = peer.create_service_proxy(GenericMediaControlServiceProxy)
|
||||||
|
mcp.on('media_state', on_media_state)
|
||||||
|
mcp.on('track_title', on_track_title)
|
||||||
|
mcp.on('track_duration', on_track_duration)
|
||||||
|
mcp.on('track_position', on_track_position)
|
||||||
|
await mcp.subscribe_characteristics()
|
||||||
|
|
||||||
|
connection.abort_on('disconnection', on_connection_async())
|
||||||
|
|
||||||
|
device.on('connection', on_connection)
|
||||||
|
|
||||||
|
async def serve(websocket: websockets.WebSocketServerProtocol, _path):
|
||||||
|
nonlocal ws
|
||||||
|
ws = websocket
|
||||||
|
async for message in websocket:
|
||||||
|
request = json.loads(message)
|
||||||
|
if mcp:
|
||||||
|
await mcp.write_control_point(
|
||||||
|
MediaControlPointOpcode(request['opcode'])
|
||||||
|
)
|
||||||
|
ws = None
|
||||||
|
|
||||||
|
await websockets.serve(serve, 'localhost', 8989)
|
||||||
|
|
||||||
|
await hci_transport.source.terminated
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||||
|
asyncio.run(main())
|
||||||
@@ -34,8 +34,8 @@ from bumble.hci import (
|
|||||||
CodingFormat,
|
CodingFormat,
|
||||||
HCI_IsoDataPacket,
|
HCI_IsoDataPacket,
|
||||||
)
|
)
|
||||||
|
from bumble.profiles.ascs import AseStateMachine, AudioStreamControlService
|
||||||
from bumble.profiles.bap import (
|
from bumble.profiles.bap import (
|
||||||
AseStateMachine,
|
|
||||||
UnicastServerAdvertisingData,
|
UnicastServerAdvertisingData,
|
||||||
CodecSpecificConfiguration,
|
CodecSpecificConfiguration,
|
||||||
CodecSpecificCapabilities,
|
CodecSpecificCapabilities,
|
||||||
@@ -43,13 +43,10 @@ from bumble.profiles.bap import (
|
|||||||
AudioLocation,
|
AudioLocation,
|
||||||
SupportedSamplingFrequency,
|
SupportedSamplingFrequency,
|
||||||
SupportedFrameDuration,
|
SupportedFrameDuration,
|
||||||
PacRecord,
|
|
||||||
PublishedAudioCapabilitiesService,
|
|
||||||
AudioStreamControlService,
|
|
||||||
)
|
)
|
||||||
from bumble.profiles.cap import CommonAudioServiceService
|
from bumble.profiles.cap import CommonAudioServiceService
|
||||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||||
|
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
|
||||||
from bumble.transport import open_transport_or_link
|
from bumble.transport import open_transport_or_link
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from bumble.hci import (
|
|||||||
CodingFormat,
|
CodingFormat,
|
||||||
OwnAddressType,
|
OwnAddressType,
|
||||||
)
|
)
|
||||||
|
from bumble.profiles.ascs import AudioStreamControlService
|
||||||
from bumble.profiles.bap import (
|
from bumble.profiles.bap import (
|
||||||
UnicastServerAdvertisingData,
|
UnicastServerAdvertisingData,
|
||||||
CodecSpecificCapabilities,
|
CodecSpecificCapabilities,
|
||||||
@@ -37,10 +38,8 @@ from bumble.profiles.bap import (
|
|||||||
AudioLocation,
|
AudioLocation,
|
||||||
SupportedSamplingFrequency,
|
SupportedSamplingFrequency,
|
||||||
SupportedFrameDuration,
|
SupportedFrameDuration,
|
||||||
PacRecord,
|
|
||||||
PublishedAudioCapabilitiesService,
|
|
||||||
AudioStreamControlService,
|
|
||||||
)
|
)
|
||||||
|
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
|
||||||
from bumble.profiles.cap import CommonAudioServiceService
|
from bumble.profiles.cap import CommonAudioServiceService
|
||||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||||
from bumble.profiles.vcp import VolumeControlService
|
from bumble.profiles.vcp import VolumeControlService
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class MainActivity : ComponentActivity() {
|
|||||||
::runRfcommClient,
|
::runRfcommClient,
|
||||||
::runRfcommServer,
|
::runRfcommServer,
|
||||||
::runL2capClient,
|
::runL2capClient,
|
||||||
::runL2capServer
|
::runL2capServer,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +166,8 @@ class MainActivity : ComponentActivity() {
|
|||||||
"rfcomm-server" -> runRfcommServer()
|
"rfcomm-server" -> runRfcommServer()
|
||||||
"l2cap-client" -> runL2capClient()
|
"l2cap-client" -> runL2capClient()
|
||||||
"l2cap-server" -> runL2capServer()
|
"l2cap-server" -> runL2capServer()
|
||||||
|
"scan-start" -> runScan(true)
|
||||||
|
"stop-start" -> runScan(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -190,6 +192,11 @@ class MainActivity : ComponentActivity() {
|
|||||||
l2capServer?.run()
|
l2capServer?.run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun runScan(startScan: Boolean) {
|
||||||
|
val scan = bluetoothAdapter?.let { Scan(it) }
|
||||||
|
scan?.run(startScan)
|
||||||
|
}
|
||||||
|
|
||||||
@SuppressLint("MissingPermission")
|
@SuppressLint("MissingPermission")
|
||||||
fun becomeDiscoverable() {
|
fun becomeDiscoverable() {
|
||||||
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
|
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
|
||||||
@@ -206,7 +213,7 @@ fun MainView(
|
|||||||
runRfcommClient: () -> Unit,
|
runRfcommClient: () -> Unit,
|
||||||
runRfcommServer: () -> Unit,
|
runRfcommServer: () -> Unit,
|
||||||
runL2capClient: () -> Unit,
|
runL2capClient: () -> Unit,
|
||||||
runL2capServer: () -> Unit
|
runL2capServer: () -> Unit,
|
||||||
) {
|
) {
|
||||||
BTBenchTheme {
|
BTBenchTheme {
|
||||||
val scrollState = rememberScrollState()
|
val scrollState = rememberScrollState()
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package com.github.google.bumble.btbench
|
||||||
|
|
||||||
|
import android.annotation.SuppressLint
|
||||||
|
import android.bluetooth.BluetoothAdapter
|
||||||
|
import android.bluetooth.BluetoothDevice
|
||||||
|
import android.bluetooth.le.ScanCallback
|
||||||
|
import android.bluetooth.le.ScanResult
|
||||||
|
import java.util.logging.Logger
|
||||||
|
|
||||||
|
private val Log = Logger.getLogger("btbench.scan")
|
||||||
|
|
||||||
|
class Scan(val bluetoothAdapter: BluetoothAdapter) {
|
||||||
|
@SuppressLint("MissingPermission")
|
||||||
|
fun run(startScan: Boolean) {
|
||||||
|
var bluetoothLeScanner = bluetoothAdapter.bluetoothLeScanner
|
||||||
|
|
||||||
|
val scanCallback = object : ScanCallback() {
|
||||||
|
override fun onScanResult(callbackType: Int, result: ScanResult?) {
|
||||||
|
super.onScanResult(callbackType, result)
|
||||||
|
val device: BluetoothDevice? = result?.device
|
||||||
|
val deviceName = device?.name ?: "Unknown"
|
||||||
|
val deviceAddress = device?.address ?: "Unknown"
|
||||||
|
Log.info("Device found: $deviceName ($deviceAddress)")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onScanFailed(errorCode: Int) {
|
||||||
|
// Handle scan failure
|
||||||
|
Log.warning("Scan failed with error code: $errorCode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (startScan) {
|
||||||
|
bluetoothLeScanner?.startScan(scanCallback)
|
||||||
|
} else {
|
||||||
|
bluetoothLeScanner?.stopScan(scanCallback)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,8 +23,9 @@ import logging
|
|||||||
|
|
||||||
from bumble import device
|
from bumble import device
|
||||||
from bumble.hci import CodecID, CodingFormat
|
from bumble.hci import CodecID, CodingFormat
|
||||||
from bumble.profiles.bap import (
|
from bumble.profiles.ascs import (
|
||||||
AudioLocation,
|
AudioStreamControlService,
|
||||||
|
AudioStreamControlServiceProxy,
|
||||||
AseStateMachine,
|
AseStateMachine,
|
||||||
ASE_Operation,
|
ASE_Operation,
|
||||||
ASE_Config_Codec,
|
ASE_Config_Codec,
|
||||||
@@ -35,6 +36,9 @@ from bumble.profiles.bap import (
|
|||||||
ASE_Receiver_Stop_Ready,
|
ASE_Receiver_Stop_Ready,
|
||||||
ASE_Release,
|
ASE_Release,
|
||||||
ASE_Update_Metadata,
|
ASE_Update_Metadata,
|
||||||
|
)
|
||||||
|
from bumble.profiles.bap import (
|
||||||
|
AudioLocation,
|
||||||
SupportedFrameDuration,
|
SupportedFrameDuration,
|
||||||
SupportedSamplingFrequency,
|
SupportedSamplingFrequency,
|
||||||
SamplingFrequency,
|
SamplingFrequency,
|
||||||
@@ -42,12 +46,13 @@ from bumble.profiles.bap import (
|
|||||||
CodecSpecificCapabilities,
|
CodecSpecificCapabilities,
|
||||||
CodecSpecificConfiguration,
|
CodecSpecificConfiguration,
|
||||||
ContextType,
|
ContextType,
|
||||||
|
)
|
||||||
|
from bumble.profiles.pacs import (
|
||||||
PacRecord,
|
PacRecord,
|
||||||
AudioStreamControlService,
|
|
||||||
AudioStreamControlServiceProxy,
|
|
||||||
PublishedAudioCapabilitiesService,
|
PublishedAudioCapabilitiesService,
|
||||||
PublishedAudioCapabilitiesServiceProxy,
|
PublishedAudioCapabilitiesServiceProxy,
|
||||||
)
|
)
|
||||||
|
from bumble.profiles.le_audio import Metadata
|
||||||
from tests.test_utils import TwoDevices
|
from tests.test_utils import TwoDevices
|
||||||
|
|
||||||
|
|
||||||
@@ -97,7 +102,7 @@ def test_pac_record() -> None:
|
|||||||
pac_record = PacRecord(
|
pac_record = PacRecord(
|
||||||
coding_format=CodingFormat(CodecID.LC3),
|
coding_format=CodingFormat(CodecID.LC3),
|
||||||
codec_specific_capabilities=cap,
|
codec_specific_capabilities=cap,
|
||||||
metadata=b'',
|
metadata=Metadata([Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')]),
|
||||||
)
|
)
|
||||||
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
|
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
|
||||||
|
|
||||||
@@ -142,7 +147,7 @@ def test_ASE_Config_QOS() -> None:
|
|||||||
def test_ASE_Enable() -> None:
|
def test_ASE_Enable() -> None:
|
||||||
operation = ASE_Enable(
|
operation = ASE_Enable(
|
||||||
ase_id=[1, 2],
|
ase_id=[1, 2],
|
||||||
metadata=[b'foo', b'bar'],
|
metadata=[b'', b''],
|
||||||
)
|
)
|
||||||
basic_check(operation)
|
basic_check(operation)
|
||||||
|
|
||||||
@@ -151,7 +156,7 @@ def test_ASE_Enable() -> None:
|
|||||||
def test_ASE_Update_Metadata() -> None:
|
def test_ASE_Update_Metadata() -> None:
|
||||||
operation = ASE_Update_Metadata(
|
operation = ASE_Update_Metadata(
|
||||||
ase_id=[1, 2],
|
ase_id=[1, 2],
|
||||||
metadata=[b'foo', b'bar'],
|
metadata=[b'', b''],
|
||||||
)
|
)
|
||||||
basic_check(operation)
|
basic_check(operation)
|
||||||
|
|
||||||
|
|||||||
146
tests/bass_test.py
Normal file
146
tests/bass_test.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# Copyright 2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from bumble import hci
|
||||||
|
from bumble.profiles import bass
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def basic_operation_check(operation: bass.ControlPointOperation) -> None:
|
||||||
|
serialized = bytes(operation)
|
||||||
|
parsed = bass.ControlPointOperation.from_bytes(serialized)
|
||||||
|
assert bytes(parsed) == serialized
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def test_operations() -> None:
|
||||||
|
op1 = bass.RemoteScanStoppedOperation()
|
||||||
|
basic_operation_check(op1)
|
||||||
|
|
||||||
|
op2 = bass.RemoteScanStartedOperation()
|
||||||
|
basic_operation_check(op2)
|
||||||
|
|
||||||
|
op3 = bass.AddSourceOperation(
|
||||||
|
hci.Address("AA:BB:CC:DD:EE:FF"),
|
||||||
|
34,
|
||||||
|
123456,
|
||||||
|
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||||
|
456,
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
basic_operation_check(op3)
|
||||||
|
|
||||||
|
op4 = bass.AddSourceOperation(
|
||||||
|
hci.Address("AA:BB:CC:DD:EE:FF"),
|
||||||
|
34,
|
||||||
|
123456,
|
||||||
|
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||||
|
456,
|
||||||
|
(
|
||||||
|
bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')),
|
||||||
|
bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
basic_operation_check(op4)
|
||||||
|
|
||||||
|
op5 = bass.ModifySourceOperation(
|
||||||
|
12,
|
||||||
|
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||||
|
567,
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
basic_operation_check(op5)
|
||||||
|
|
||||||
|
op6 = bass.ModifySourceOperation(
|
||||||
|
12,
|
||||||
|
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||||
|
567,
|
||||||
|
(
|
||||||
|
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
|
||||||
|
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
basic_operation_check(op6)
|
||||||
|
|
||||||
|
op7 = bass.SetBroadcastCodeOperation(
|
||||||
|
7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf')
|
||||||
|
)
|
||||||
|
basic_operation_check(op7)
|
||||||
|
|
||||||
|
op8 = bass.RemoveSourceOperation(7)
|
||||||
|
basic_operation_check(op8)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None:
|
||||||
|
serialized = bytes(brs)
|
||||||
|
parsed = bass.BroadcastReceiveState.from_bytes(serialized)
|
||||||
|
assert parsed is not None
|
||||||
|
assert bytes(parsed) == serialized
|
||||||
|
|
||||||
|
|
||||||
|
def test_broadcast_receive_state() -> None:
|
||||||
|
subgroups = [
|
||||||
|
bass.SubgroupInfo(6677, bytes.fromhex('112233')),
|
||||||
|
bass.SubgroupInfo(8899, bytes.fromhex('4567')),
|
||||||
|
]
|
||||||
|
|
||||||
|
brs1 = bass.BroadcastReceiveState(
|
||||||
|
12,
|
||||||
|
hci.Address("AA:BB:CC:DD:EE:FF"),
|
||||||
|
123,
|
||||||
|
123456,
|
||||||
|
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
|
||||||
|
bass.BroadcastReceiveState.BigEncryption.DECRYPTING,
|
||||||
|
b'',
|
||||||
|
subgroups,
|
||||||
|
)
|
||||||
|
basic_broadcast_receive_state_check(brs1)
|
||||||
|
|
||||||
|
brs2 = bass.BroadcastReceiveState(
|
||||||
|
12,
|
||||||
|
hci.Address("AA:BB:CC:DD:EE:FF"),
|
||||||
|
123,
|
||||||
|
123456,
|
||||||
|
bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
|
||||||
|
bass.BroadcastReceiveState.BigEncryption.BAD_CODE,
|
||||||
|
bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'),
|
||||||
|
subgroups,
|
||||||
|
)
|
||||||
|
basic_broadcast_receive_state_check(brs2)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def run():
|
||||||
|
test_operations()
|
||||||
|
test_broadcast_receive_state()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
asyncio.run(run())
|
||||||
@@ -276,34 +276,6 @@ async def test_legacy_advertising():
|
|||||||
assert not device.is_advertising
|
assert not device.is_advertising
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
'own_address_type,',
|
|
||||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
|
||||||
)
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_legacy_advertising_connection(own_address_type):
|
|
||||||
device = Device(host=mock.AsyncMock(Host))
|
|
||||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
|
||||||
|
|
||||||
# Start advertising
|
|
||||||
await device.start_advertising()
|
|
||||||
device.on_connection(
|
|
||||||
0x0001,
|
|
||||||
BT_LE_TRANSPORT,
|
|
||||||
peer_address,
|
|
||||||
BT_PERIPHERAL_ROLE,
|
|
||||||
ConnectionParameters(0, 0, 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
if own_address_type == OwnAddressType.PUBLIC:
|
|
||||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
|
||||||
else:
|
|
||||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
|
||||||
|
|
||||||
await async_barrier()
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'auto_restart,',
|
'auto_restart,',
|
||||||
@@ -318,6 +290,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
|
|||||||
0x0001,
|
0x0001,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
peer_address,
|
peer_address,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
BT_PERIPHERAL_ROLE,
|
BT_PERIPHERAL_ROLE,
|
||||||
ConnectionParameters(0, 0, 0),
|
ConnectionParameters(0, 0, 0),
|
||||||
)
|
)
|
||||||
@@ -367,6 +341,8 @@ async def test_extended_advertising_connection(own_address_type):
|
|||||||
0x0001,
|
0x0001,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
peer_address,
|
peer_address,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
BT_PERIPHERAL_ROLE,
|
BT_PERIPHERAL_ROLE,
|
||||||
ConnectionParameters(0, 0, 0),
|
ConnectionParameters(0, 0, 0),
|
||||||
)
|
)
|
||||||
@@ -407,6 +383,8 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
|
|||||||
0x0001,
|
0x0001,
|
||||||
BT_LE_TRANSPORT,
|
BT_LE_TRANSPORT,
|
||||||
peer_address,
|
peer_address,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
BT_PERIPHERAL_ROLE,
|
BT_PERIPHERAL_ROLE,
|
||||||
ConnectionParameters(0, 0, 0),
|
ConnectionParameters(0, 0, 0),
|
||||||
)
|
)
|
||||||
@@ -558,6 +536,16 @@ async def test_cis_setup_failure():
|
|||||||
await asyncio.wait_for(cis_create_task, _TIMEOUT)
|
await asyncio.wait_for(cis_create_task, _TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_power_on_default_static_address_should_not_be_any():
|
||||||
|
devices = TwoDevices()
|
||||||
|
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
|
||||||
|
await devices[0].power_on()
|
||||||
|
|
||||||
|
assert devices[0].static_address != Address.ANY_RANDOM
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_gatt_services_with_gas():
|
def test_gatt_services_with_gas():
|
||||||
device = Device(host=Host(None, None))
|
device = Device(host=Host(None, None))
|
||||||
|
|||||||
@@ -879,6 +879,57 @@ async def test_unsubscribe():
|
|||||||
mock1.assert_called_once_with(ANY, False, False)
|
mock1.assert_called_once_with(ANY, False, False)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_all():
|
||||||
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
|
characteristic1 = Characteristic(
|
||||||
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
|
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
bytes([1, 2, 3]),
|
||||||
|
)
|
||||||
|
|
||||||
|
descriptor1 = Descriptor('2902', 'READABLE,WRITEABLE')
|
||||||
|
descriptor2 = Descriptor('AAAA', 'READABLE,WRITEABLE')
|
||||||
|
characteristic2 = Characteristic(
|
||||||
|
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
|
||||||
|
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
bytes([1, 2, 3]),
|
||||||
|
descriptors=[descriptor1, descriptor2],
|
||||||
|
)
|
||||||
|
|
||||||
|
service1 = Service(
|
||||||
|
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
|
||||||
|
[characteristic1, characteristic2],
|
||||||
|
)
|
||||||
|
service2 = Service('1111', [])
|
||||||
|
server.add_services([service1, service2])
|
||||||
|
|
||||||
|
await client.power_on()
|
||||||
|
await server.power_on()
|
||||||
|
connection = await client.connect(server.random_address)
|
||||||
|
peer = Peer(connection)
|
||||||
|
|
||||||
|
await peer.discover_all()
|
||||||
|
assert len(peer.gatt_client.services) == 3
|
||||||
|
# service 1800 gets added automatically
|
||||||
|
assert peer.gatt_client.services[0].uuid == UUID('1800')
|
||||||
|
assert peer.gatt_client.services[1].uuid == service1.uuid
|
||||||
|
assert peer.gatt_client.services[2].uuid == service2.uuid
|
||||||
|
s = peer.get_services_by_uuid(service1.uuid)
|
||||||
|
assert len(s) == 1
|
||||||
|
assert len(s[0].characteristics) == 2
|
||||||
|
c = peer.get_characteristics_by_uuid(uuid=characteristic2.uuid, service=s[0])
|
||||||
|
assert len(c) == 1
|
||||||
|
assert len(c[0].descriptors) == 2
|
||||||
|
s = peer.get_services_by_uuid(service2.uuid)
|
||||||
|
assert len(s) == 1
|
||||||
|
assert len(s[0].characteristics) == 0
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mtu_exchange():
|
async def test_mtu_exchange():
|
||||||
@@ -1146,6 +1197,56 @@ def test_get_attribute_group():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_characteristics_by_uuid():
|
||||||
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
|
characteristic1 = Characteristic(
|
||||||
|
'1234',
|
||||||
|
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
bytes([1, 2, 3]),
|
||||||
|
)
|
||||||
|
characteristic2 = Characteristic(
|
||||||
|
'5678',
|
||||||
|
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||||
|
Characteristic.READABLE,
|
||||||
|
bytes([1, 2, 3]),
|
||||||
|
)
|
||||||
|
service1 = Service(
|
||||||
|
'ABCD',
|
||||||
|
[characteristic1, characteristic2],
|
||||||
|
)
|
||||||
|
service2 = Service(
|
||||||
|
'FFFF',
|
||||||
|
[characteristic1],
|
||||||
|
)
|
||||||
|
|
||||||
|
server.add_services([service1, service2])
|
||||||
|
|
||||||
|
await client.power_on()
|
||||||
|
await server.power_on()
|
||||||
|
connection = await client.connect(server.random_address)
|
||||||
|
peer = Peer(connection)
|
||||||
|
|
||||||
|
await peer.discover_services()
|
||||||
|
await peer.discover_characteristics()
|
||||||
|
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
|
||||||
|
assert len(c) == 2
|
||||||
|
assert isinstance(c[0], CharacteristicProxy)
|
||||||
|
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
|
||||||
|
assert len(c) == 1
|
||||||
|
assert isinstance(c[0], CharacteristicProxy)
|
||||||
|
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
|
||||||
|
assert len(c) == 0
|
||||||
|
|
||||||
|
s = peer.get_services_by_uuid(uuid=UUID('ABCD'))
|
||||||
|
assert len(s) == 1
|
||||||
|
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=s[0])
|
||||||
|
assert len(s) == 1
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ def test_import():
|
|||||||
core,
|
core,
|
||||||
crypto,
|
crypto,
|
||||||
device,
|
device,
|
||||||
gap,
|
|
||||||
hci,
|
hci,
|
||||||
hfp,
|
hfp,
|
||||||
host,
|
host,
|
||||||
@@ -41,6 +40,22 @@ def test_import():
|
|||||||
utils,
|
utils,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from bumble.profiles import (
|
||||||
|
ascs,
|
||||||
|
bap,
|
||||||
|
bass,
|
||||||
|
battery_service,
|
||||||
|
cap,
|
||||||
|
csip,
|
||||||
|
device_information_service,
|
||||||
|
gap,
|
||||||
|
heart_rate_service,
|
||||||
|
le_audio,
|
||||||
|
pacs,
|
||||||
|
pbp,
|
||||||
|
vcp,
|
||||||
|
)
|
||||||
|
|
||||||
assert att
|
assert att
|
||||||
assert bridge
|
assert bridge
|
||||||
assert company_ids
|
assert company_ids
|
||||||
@@ -48,7 +63,6 @@ def test_import():
|
|||||||
assert core
|
assert core
|
||||||
assert crypto
|
assert crypto
|
||||||
assert device
|
assert device
|
||||||
assert gap
|
|
||||||
assert hci
|
assert hci
|
||||||
assert hfp
|
assert hfp
|
||||||
assert host
|
assert host
|
||||||
@@ -61,6 +75,20 @@ def test_import():
|
|||||||
assert transport
|
assert transport
|
||||||
assert utils
|
assert utils
|
||||||
|
|
||||||
|
assert ascs
|
||||||
|
assert bap
|
||||||
|
assert bass
|
||||||
|
assert battery_service
|
||||||
|
assert cap
|
||||||
|
assert csip
|
||||||
|
assert device_information_service
|
||||||
|
assert gap
|
||||||
|
assert heart_rate_service
|
||||||
|
assert le_audio
|
||||||
|
assert pacs
|
||||||
|
assert pbp
|
||||||
|
assert vcp
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_app_imports():
|
def test_app_imports():
|
||||||
|
|||||||
39
tests/le_audio_test.py
Normal file
39
tests/le_audio_test.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Copyright 2021-2024 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
from bumble.profiles import le_audio
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_metadata():
|
||||||
|
metadata = le_audio.Metadata(
|
||||||
|
entries=[
|
||||||
|
le_audio.Metadata.Entry(
|
||||||
|
tag=le_audio.Metadata.Tag.PROGRAM_INFO,
|
||||||
|
data=b'',
|
||||||
|
),
|
||||||
|
le_audio.Metadata.Entry(
|
||||||
|
tag=le_audio.Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
|
||||||
|
data=bytes([0, 0]),
|
||||||
|
),
|
||||||
|
le_audio.Metadata.Entry(
|
||||||
|
tag=le_audio.Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
|
||||||
|
data=bytes([1, 2]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert le_audio.Metadata.from_bytes(bytes(metadata)) == metadata
|
||||||
132
tests/mcp_test.py
Normal file
132
tests/mcp_test.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
# Copyright 2021-2023 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
import struct
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from bumble import device
|
||||||
|
from bumble.profiles import mcp
|
||||||
|
from tests.test_utils import TwoDevices
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
TIMEOUT = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GmcsContext:
|
||||||
|
devices: TwoDevices
|
||||||
|
client: mcp.GenericMediaControlServiceProxy
|
||||||
|
server: mcp.GenericMediaControlService
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def gmcs_context():
|
||||||
|
devices = TwoDevices()
|
||||||
|
server = mcp.GenericMediaControlService()
|
||||||
|
devices[0].add_service(server)
|
||||||
|
|
||||||
|
await devices.setup_connection()
|
||||||
|
devices.connections[0].encryption = 1
|
||||||
|
devices.connections[1].encryption = 1
|
||||||
|
peer = device.Peer(devices.connections[1])
|
||||||
|
client = await peer.discover_service_and_create_proxy(
|
||||||
|
mcp.GenericMediaControlServiceProxy
|
||||||
|
)
|
||||||
|
await client.subscribe_characteristics()
|
||||||
|
|
||||||
|
return GmcsContext(devices=devices, server=server, client=client)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_media_state(gmcs_context):
|
||||||
|
state = asyncio.Queue()
|
||||||
|
gmcs_context.client.on('media_state', state.put_nowait)
|
||||||
|
|
||||||
|
await gmcs_context.devices[0].notify_subscribers(
|
||||||
|
gmcs_context.server.media_state_characteristic,
|
||||||
|
value=bytes([mcp.MediaState.PLAYING]),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == mcp.MediaState.PLAYING
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_track_title(gmcs_context):
|
||||||
|
state = asyncio.Queue()
|
||||||
|
gmcs_context.client.on('track_title', state.put_nowait)
|
||||||
|
|
||||||
|
await gmcs_context.devices[0].notify_subscribers(
|
||||||
|
gmcs_context.server.track_title_characteristic,
|
||||||
|
value="My Song".encode(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == "My Song"
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_track_duration(gmcs_context):
|
||||||
|
state = asyncio.Queue()
|
||||||
|
gmcs_context.client.on('track_duration', state.put_nowait)
|
||||||
|
|
||||||
|
await gmcs_context.devices[0].notify_subscribers(
|
||||||
|
gmcs_context.server.track_duration_characteristic,
|
||||||
|
value=struct.pack("<i", 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_track_position(gmcs_context):
|
||||||
|
state = asyncio.Queue()
|
||||||
|
gmcs_context.client.on('track_position', state.put_nowait)
|
||||||
|
|
||||||
|
await gmcs_context.devices[0].notify_subscribers(
|
||||||
|
gmcs_context.server.track_position_characteristic,
|
||||||
|
value=struct.pack("<i", 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_media_control_point(gmcs_context):
|
||||||
|
assert (
|
||||||
|
await asyncio.wait_for(
|
||||||
|
gmcs_context.client.write_control_point(mcp.MediaControlPointOpcode.PAUSE),
|
||||||
|
TIMEOUT,
|
||||||
|
)
|
||||||
|
) == mcp.MediaControlPointResultCode.SUCCESS
|
||||||
@@ -17,13 +17,17 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from bumble import smp
|
from bumble import smp
|
||||||
|
from bumble import pairing
|
||||||
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
||||||
from bumble.pairing import OobData, OobSharedData, LeRole
|
from bumble.pairing import OobData, OobSharedData, LeRole
|
||||||
from bumble.hci import Address
|
from bumble.hci import Address
|
||||||
from bumble.core import AdvertisingData
|
from bumble.core import AdvertisingData
|
||||||
|
from bumble.device import Device
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
@@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
|
|||||||
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
|
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'identity_address_type, public_address, random_address, expected_identity_address',
|
||||||
|
[
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
Address.ANY,
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
pairing.PairingConfig.AddressType.PUBLIC,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
pairing.PairingConfig.AddressType.RANDOM,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_identity_address_command(
|
||||||
|
identity_address_type: Optional[pairing.PairingConfig.AddressType],
|
||||||
|
public_address: Address,
|
||||||
|
random_address: Address,
|
||||||
|
expected_identity_address: Address,
|
||||||
|
):
|
||||||
|
device = Device()
|
||||||
|
device.public_address = public_address
|
||||||
|
device.static_address = random_address
|
||||||
|
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
|
||||||
|
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
|
||||||
|
|
||||||
|
with mock.patch.object(session, 'send_command') as mock_method:
|
||||||
|
session.send_identity_address_command()
|
||||||
|
|
||||||
|
actual_command = mock_method.call_args.args[0]
|
||||||
|
assert actual_command.addr_type == expected_identity_address.address_type
|
||||||
|
assert actual_command.bd_addr == expected_identity_address
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_ecc()
|
test_ecc()
|
||||||
|
|||||||
@@ -51,3 +51,5 @@ Example:
|
|||||||
|
|
||||||
NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory.
|
NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory.
|
||||||
Make a copy of the built `.whl` file in the `web` directory.
|
Make a copy of the built `.whl` file in the `web` directory.
|
||||||
|
|
||||||
|
Tip: During web developement, disable caching. [Chrome](https://stackoverflow.com/a/7000899]) / [Firefiox](https://stackoverflow.com/a/289771)
|
||||||
1
web/favicon.ico
Symbolic link
1
web/favicon.ico
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../docs/images/favicon.ico
|
||||||
@@ -15,12 +15,21 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
import pyee
|
||||||
|
|
||||||
from bumble.device import Device
|
from bumble.device import Device
|
||||||
from bumble.hci import HCI_Reset_Command
|
from bumble.hci import HCI_Reset_Command
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Scanner:
|
class Scanner(pyee.EventEmitter):
|
||||||
|
"""
|
||||||
|
Scanner web app
|
||||||
|
|
||||||
|
Emitted events:
|
||||||
|
update: Emit when new `ScanEntry` are available.
|
||||||
|
"""
|
||||||
|
|
||||||
class ScanEntry:
|
class ScanEntry:
|
||||||
def __init__(self, advertisement):
|
def __init__(self, advertisement):
|
||||||
self.address = advertisement.address.to_string(False)
|
self.address = advertisement.address.to_string(False)
|
||||||
@@ -39,13 +48,12 @@ class Scanner:
|
|||||||
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||||
)
|
)
|
||||||
self.scan_entries = {}
|
self.scan_entries = {}
|
||||||
self.listeners = {}
|
|
||||||
self.device.on('advertisement', self.on_advertisement)
|
self.device.on('advertisement', self.on_advertisement)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
print('### Starting Scanner')
|
print('### Starting Scanner')
|
||||||
self.scan_entries = {}
|
self.scan_entries = {}
|
||||||
self.emit_update()
|
self.emit('update', self.scan_entries)
|
||||||
await self.device.power_on()
|
await self.device.power_on()
|
||||||
await self.device.start_scanning()
|
await self.device.start_scanning()
|
||||||
print('### Scanner started')
|
print('### Scanner started')
|
||||||
@@ -56,16 +64,9 @@ class Scanner:
|
|||||||
await self.device.power_off()
|
await self.device.power_off()
|
||||||
print('### Scanner stopped')
|
print('### Scanner stopped')
|
||||||
|
|
||||||
def emit_update(self):
|
|
||||||
if listener := self.listeners.get('update'):
|
|
||||||
listener(list(self.scan_entries.values()))
|
|
||||||
|
|
||||||
def on(self, event_name, listener):
|
|
||||||
self.listeners[event_name] = listener
|
|
||||||
|
|
||||||
def on_advertisement(self, advertisement):
|
def on_advertisement(self, advertisement):
|
||||||
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
|
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
|
||||||
self.emit_update()
|
self.emit('update', self.scan_entries)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user