Merge pull request #372 from zxzxwu/source

ASCS Source Implementation
This commit is contained in:
zxzxwu
2024-05-29 13:17:51 +08:00
committed by GitHub
14 changed files with 978 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
{ {
"cSpell.words": [ "cSpell.words": [
"Abortable", "Abortable",
"aiohttp",
"altsetting", "altsetting",
"ansiblue", "ansiblue",
"ansicyan", "ansicyan",
@@ -9,6 +10,7 @@
"ansired", "ansired",
"ansiyellow", "ansiyellow",
"appendleft", "appendleft",
"ascs",
"ASHA", "ASHA",
"asyncio", "asyncio",
"ATRAC", "ATRAC",
@@ -43,6 +45,7 @@
"keyup", "keyup",
"levelname", "levelname",
"libc", "libc",
"liblc",
"libusb", "libusb",
"MITM", "MITM",
"MSBC", "MSBC",
@@ -78,6 +81,7 @@
"unmuted", "unmuted",
"usbmodem", "usbmodem",
"vhci", "vhci",
"wasmtime",
"websockets", "websockets",
"xcursor", "xcursor",
"ycursor" "ycursor"

577
apps/lea_unicast/app.py Normal file
View File

@@ -0,0 +1,577 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1],
min_octets_per_codec_frame=30,
max_octets_per_codec_frame=100,
supported_max_codec_frames_per_sdu=1,
),
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task(
filename: str,
sdu_length: int,
frame_duration_us: int,
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = aiohttp.web.Application()
app.add_routes(
[
aiohttp.web.get('/', self.get_static),
aiohttp.web.get('/index.html', self.get_static),
aiohttp.web.get('/channel', self.get_channel),
]
)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/index.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.lea_unicast")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: Optional[str],
ui_port: int,
transport: str,
lc3_input_file_path: str,
):
self.device_config_path = device_config_path
self.transport = transport
self.lc3_input_file_path = lc3_input_file_path
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
async def run(self) -> None:
await self.ui_server.start_http()
async with await open_transport(self.transport) as hci_transport:
# Create a device
if self.device_config_path:
device_config = DeviceConfiguration.from_file(self.device_config_path)
else:
device_config = DeviceConfiguration(
name="Bumble LE Headphone",
class_of_device=0x244418,
keystore="JsonKeyStore",
advertising_interval_min=25,
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
)
device_config.le_enabled = True
device_config.cis_enabled = True
self.device = Device.from_config_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
self.device.add_service(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
available_sink_context=bap.ContextType(0xFFFF), # All context types
sink_audio_locations=(
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
),
sink_pac=[_sink_pac_record()],
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
source_pac=[_source_pac_record()],
)
)
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
filename=self.lc3_input_file_path,
sdu_length=(
codec_config.codec_frames_per_sdu
* codec_config.octets_per_codec_frame
),
frame_duration_us=codec_config.frame_duration.us,
device=self.device,
cis_handle=ase.cis_link.handle,
),
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()
await self.device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await hci_transport.source.terminated
@click.command()
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
"""Run the speaker."""
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,68 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
</div>
</nav>
<br>
<div class="container">
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
<button class="btn btn-primary" type="button" disabled>
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
<span role="status" id="ws-status">WebSocket Connecting...</span>
</button>
</div>
<script>
let player = null;
const wsStatus = document.getElementById("ws-status");
const wsStatusSpinner = document.getElementById("ws-status-spinner");
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
socket.binaryType = "arraybuffer";
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
console.log(`channel MESSAGE: ${message.data}`);
} else {
console.log(typeof (message.data))
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
socket.onopen = (message) => {
wsStatusSpinner.remove();
wsStatus.textContent = "WebSocket Connected";
}
socket.onclose = (message) => {
wsStatus.textContent = "WebSocket Disconnected";
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 2,
sampleRate: 48000,
flushTime: 10,
});
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
</script>
</div>
</body>
</html>

BIN
apps/lea_unicast/liblc3.wasm Executable file

Binary file not shown.

View File

@@ -23,7 +23,13 @@ import json
import asyncio import asyncio
import logging import logging
import secrets import secrets
from contextlib import asynccontextmanager, AsyncExitStack, closing import sys
from contextlib import (
asynccontextmanager,
AsyncExitStack,
closing,
AbstractAsyncContextManager,
)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from collections.abc import Iterable from collections.abc import Iterable
from typing import ( from typing import (
@@ -961,8 +967,9 @@ class ScoLink(CompositeEventEmitter):
acl_connection: Connection acl_connection: Connection
handle: int handle: int
link_type: int link_type: int
sink: Optional[Callable[[HCI_SynchronousDataPacket], Any]] = None
def __post_init__(self): def __post_init__(self) -> None:
super().__init__() super().__init__()
async def disconnect( async def disconnect(
@@ -984,8 +991,9 @@ class CisLink(CompositeEventEmitter):
cis_id: int # CIS ID assigned by Central device cis_id: int # CIS ID assigned by Central device
cig_id: int # CIG ID assigned by Central device cig_id: int # CIG ID assigned by Central device
state: State = State.PENDING state: State = State.PENDING
sink: Optional[Callable[[HCI_IsoDataPacket], Any]] = None
def __post_init__(self): def __post_init__(self) -> None:
super().__init__() super().__init__()
async def disconnect( async def disconnect(
@@ -1533,6 +1541,12 @@ class Device(CompositeEventEmitter):
Address.ANY: [] Address.ANY: []
} # Futures, by BD address OR [Futures] for Address.ANY } # Futures, by BD address OR [Futures] for Address.ANY
# In Python <= 3.9 + Rust Runtime, asyncio.Lock cannot be properly initiated.
if sys.version_info >= (3, 10):
self._cis_lock = asyncio.Lock()
else:
self._cis_lock = AsyncExitStack()
# Own address type cache # Own address type cache
self.connect_own_address_type = None self.connect_own_address_type = None
@@ -3406,49 +3420,71 @@ class Device(CompositeEventEmitter):
for cis_handle, _ in cis_acl_pairs for cis_handle, _ in cis_acl_pairs
} }
@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None: def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(cis_link.handle): if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link) pending_future.set_result(cis_link)
result = await self.send_command( def on_cis_establishment_failure(cis_handle: int, status: int) -> None:
if pending_future := pending_cis_establishments.get(cis_handle):
pending_future.set_exception(HCI_Error(status))
watcher.on(self, 'cis_establishment', on_cis_establishment)
watcher.on(self, 'cis_establishment_failure', on_cis_establishment_failure)
await self.send_command(
HCI_LE_Create_CIS_Command( HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs], cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs], acl_connection_handle=[p[1] for p in cis_acl_pairs],
), ),
check_result=True,
) )
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
return await asyncio.gather(*pending_cis_establishments.values()) return await asyncio.gather(*pending_cis_establishments.values())
# [LE only] # [LE only]
@experimental('Only for testing.') @experimental('Only for testing.')
async def accept_cis_request(self, handle: int) -> CisLink: async def accept_cis_request(self, handle: int) -> CisLink:
result = await self.send_command( """[LE Only] Accepts an incoming CIS request.
HCI_LE_Accept_CIS_Request_Command(connection_handle=handle),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Accept_CIS_Request_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
pending_cis_establishment = asyncio.get_running_loop().create_future() When the specified CIS handle is already created, this method returns the
existed CIS link object immediately.
with closing(EventWatcher()) as watcher: Args:
handle: CIS handle to accept.
@watcher.on(self, 'cis_establishment') Returns:
def on_cis_establishment(cis_link: CisLink) -> None: CIS link object on the given handle.
if cis_link.handle == handle: """
pending_cis_establishment.set_result(cis_link) if not (cis_link := self.cis_links.get(handle)):
raise InvalidStateError(f'No pending CIS request of handle {handle}')
return await pending_cis_establishment # There might be multiple ASE sharing a CIS channel.
# If one of them has accepted the request, the others should just leverage it.
async with self._cis_lock:
if cis_link.state == CisLink.State.ESTABLISHED:
return cis_link
with closing(EventWatcher()) as watcher:
pending_establishment = asyncio.get_running_loop().create_future()
def on_establishment() -> None:
pending_establishment.set_result(None)
def on_establishment_failure(status: int) -> None:
pending_establishment.set_exception(HCI_Error(status))
watcher.on(cis_link, 'establishment', on_establishment)
watcher.on(cis_link, 'establishment_failure', on_establishment_failure)
await self.send_command(
HCI_LE_Accept_CIS_Request_Command(connection_handle=handle),
check_result=True,
)
await pending_establishment
return cis_link
# Mypy believes this is reachable when context is an ExitStack.
raise InvalidStateError('Unreachable')
# [LE only] # [LE only]
@experimental('Only for testing.') @experimental('Only for testing.')
@@ -3457,15 +3493,10 @@ class Device(CompositeEventEmitter):
handle: int, handle: int,
reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
) -> None: ) -> None:
result = await self.send_command( await self.send_command(
HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason), HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason),
check_result=True,
) )
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Reject_CIS_Request_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
async def get_remote_le_features(self, connection: Connection) -> LeFeatureMask: async def get_remote_le_features(self, connection: Connection) -> LeFeatureMask:
"""[LE Only] Reads remote LE supported features. """[LE Only] Reads remote LE supported features.
@@ -3485,11 +3516,17 @@ class Device(CompositeEventEmitter):
if handle == connection.handle: if handle == connection.handle:
read_feature_future.set_result(LeFeatureMask(features)) read_feature_future.set_result(LeFeatureMask(features))
def on_failure(handle: int, status: int):
if handle == connection.handle:
read_feature_future.set_exception(HCI_Error(status))
watcher.on(self.host, 'le_remote_features', on_le_remote_features) watcher.on(self.host, 'le_remote_features', on_le_remote_features)
watcher.on(self.host, 'le_remote_features_failure', on_failure)
await self.send_command( await self.send_command(
HCI_LE_Read_Remote_Features_Command( HCI_LE_Read_Remote_Features_Command(
connection_handle=connection.handle connection_handle=connection.handle
), ),
check_result=True,
) )
return await read_feature_future return await read_feature_future
@@ -4111,8 +4148,8 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
@experimental('Only for testing') @experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle): if (sco_link := self.sco_links.get(sco_handle)) and sco_link.sink:
sco_link.emit('pdu', packet) sco_link.sink(packet)
# [LE only] # [LE only]
@host_event_handler @host_event_handler
@@ -4168,15 +4205,15 @@ class Device(CompositeEventEmitter):
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle): if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure') cis_link.emit('establishment_failure', status)
self.emit('cis_establishment_failure', cis_handle, status) self.emit('cis_establishment_failure', cis_handle, status)
# [LE only] # [LE only]
@host_event_handler @host_event_handler
@experimental('Only for testing') @experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle): if (cis_link := self.cis_links.get(handle)) and cis_link.sink:
cis_link.emit('pdu', packet) cis_link.sink(packet)
@host_event_handler @host_event_handler
@with_connection_from_handle @with_connection_from_handle

View File

@@ -23,7 +23,7 @@ import functools
import logging import logging
import secrets import secrets
import struct import struct
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, ClassVar
from bumble import crypto from bumble import crypto
from .colors import color from .colors import color
@@ -2003,7 +2003,7 @@ class HCI_Packet:
Abstract Base class for HCI packets Abstract Base class for HCI packets
''' '''
hci_packet_type: int hci_packet_type: ClassVar[int]
@staticmethod @staticmethod
def from_bytes(packet: bytes) -> HCI_Packet: def from_bytes(packet: bytes) -> HCI_Packet:
@@ -6192,12 +6192,23 @@ class HCI_SynchronousDataPacket(HCI_Packet):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class HCI_IsoDataPacket(HCI_Packet): class HCI_IsoDataPacket(HCI_Packet):
''' '''
See Bluetooth spec @ 5.4.5 HCI ISO Data Packets See Bluetooth spec @ 5.4.5 HCI ISO Data Packets
''' '''
hci_packet_type = HCI_ISO_DATA_PACKET hci_packet_type: ClassVar[int] = HCI_ISO_DATA_PACKET
connection_handle: int
data_total_length: int
iso_sdu_fragment: bytes
pb_flag: int
ts_flag: int = 0
time_stamp: Optional[int] = None
packet_sequence_number: Optional[int] = None
iso_sdu_length: Optional[int] = None
packet_status_flag: Optional[int] = None
@staticmethod @staticmethod
def from_bytes(packet: bytes) -> HCI_IsoDataPacket: def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
@@ -6241,28 +6252,6 @@ class HCI_IsoDataPacket(HCI_Packet):
iso_sdu_fragment=iso_sdu_fragment, iso_sdu_fragment=iso_sdu_fragment,
) )
def __init__(
self,
connection_handle: int,
pb_flag: int,
ts_flag: int,
data_total_length: int,
time_stamp: Optional[int],
packet_sequence_number: Optional[int],
iso_sdu_length: Optional[int],
packet_status_flag: Optional[int],
iso_sdu_fragment: bytes,
) -> None:
self.connection_handle = connection_handle
self.pb_flag = pb_flag
self.ts_flag = ts_flag
self.data_total_length = data_total_length
self.time_stamp = time_stamp
self.packet_sequence_number = packet_sequence_number
self.iso_sdu_length = iso_sdu_length
self.packet_status_flag = packet_status_flag
self.iso_sdu_fragment = iso_sdu_fragment
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.to_bytes() return self.to_bytes()

View File

@@ -721,14 +721,16 @@ class Host(AbortableEventEmitter):
for connection_handle, num_completed_packets in zip( for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets event.connection_handles, event.num_completed_packets
): ):
if not (connection := self.connections.get(connection_handle)): if connection := self.connections.get(connection_handle):
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
elif not (
self.cis_links.get(connection_handle)
or self.sco_links.get(connection_handle)
):
logger.warning( logger.warning(
'received packet completion event for unknown handle ' 'received packet completion event for unknown handle '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
) )
continue
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
# Classic only # Classic only
def on_hci_connection_request_event(self, event): def on_hci_connection_request_event(self, event):

View File

@@ -78,6 +78,10 @@ class AudioLocation(enum.IntFlag):
LEFT_SURROUND = 0x04000000 LEFT_SURROUND = 0x04000000
RIGHT_SURROUND = 0x08000000 RIGHT_SURROUND = 0x08000000
@property
def channel_count(self) -> int:
return bin(self.value).count('1')
class AudioInputType(enum.IntEnum): class AudioInputType(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type''' '''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type'''
@@ -218,6 +222,13 @@ class FrameDuration(enum.IntEnum):
DURATION_7500_US = 0x00 DURATION_7500_US = 0x00
DURATION_10000_US = 0x01 DURATION_10000_US = 0x01
@property
def us(self) -> int:
return {
FrameDuration.DURATION_7500_US: 7500,
FrameDuration.DURATION_10000_US: 10000,
}[self]
class SupportedFrameDuration(enum.IntFlag): class SupportedFrameDuration(enum.IntFlag):
'''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration''' '''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration'''
@@ -534,7 +545,7 @@ class CodecSpecificCapabilities:
supported_sampling_frequencies: SupportedSamplingFrequency supported_sampling_frequencies: SupportedSamplingFrequency
supported_frame_durations: SupportedFrameDuration supported_frame_durations: SupportedFrameDuration
supported_audio_channel_counts: Sequence[int] supported_audio_channel_count: Sequence[int]
min_octets_per_codec_frame: int min_octets_per_codec_frame: int
max_octets_per_codec_frame: int max_octets_per_codec_frame: int
supported_max_codec_frames_per_sdu: int supported_max_codec_frames_per_sdu: int
@@ -543,7 +554,7 @@ class CodecSpecificCapabilities:
def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities:
offset = 0 offset = 0
# Allowed default values. # Allowed default values.
supported_audio_channel_counts = [1] supported_audio_channel_count = [1]
supported_max_codec_frames_per_sdu = 1 supported_max_codec_frames_per_sdu = 1
while offset < len(data): while offset < len(data):
length, type = struct.unpack_from('BB', data, offset) length, type = struct.unpack_from('BB', data, offset)
@@ -556,7 +567,7 @@ class CodecSpecificCapabilities:
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION: elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value) supported_frame_durations = SupportedFrameDuration(value)
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_counts = bits_to_channel_counts(value) supported_audio_channel_count = bits_to_channel_counts(value)
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16 max_octets_per_sample = value >> 16
@@ -567,7 +578,7 @@ class CodecSpecificCapabilities:
return CodecSpecificCapabilities( return CodecSpecificCapabilities(
supported_sampling_frequencies=supported_sampling_frequencies, supported_sampling_frequencies=supported_sampling_frequencies,
supported_frame_durations=supported_frame_durations, supported_frame_durations=supported_frame_durations,
supported_audio_channel_counts=supported_audio_channel_counts, supported_audio_channel_count=supported_audio_channel_count,
min_octets_per_codec_frame=min_octets_per_sample, min_octets_per_codec_frame=min_octets_per_sample,
max_octets_per_codec_frame=max_octets_per_sample, max_octets_per_codec_frame=max_octets_per_sample,
supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu, supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu,
@@ -584,7 +595,7 @@ class CodecSpecificCapabilities:
self.supported_frame_durations, self.supported_frame_durations,
2, 2,
CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT, CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT,
channel_counts_to_bits(self.supported_audio_channel_counts), channel_counts_to_bits(self.supported_audio_channel_count),
5, 5,
CodecSpecificCapabilities.Type.OCTETS_PER_FRAME, CodecSpecificCapabilities.Type.OCTETS_PER_FRAME,
self.min_octets_per_codec_frame, self.min_octets_per_codec_frame,
@@ -870,15 +881,22 @@ class AseStateMachine(gatt.Characteristic):
cig_id: int, cig_id: int,
cis_id: int, cis_id: int,
) -> None: ) -> None:
if cis_id == self.cis_id and self.state == self.State.ENABLING: if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on( acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle) 'flush', self.service.device.accept_cis_request(cis_handle)
) )
def on_cis_establishment(self, cis_link: device.CisLink) -> None: def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: if (
self.state = self.State.STREAMING cis_link.cig_id == self.cig_id
self.cis_link = cis_link and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established(): async def post_cis_established():
await self.service.device.send_command( await self.service.device.send_command(
@@ -891,9 +909,15 @@ class AseStateMachine(gatt.Characteristic):
codec_configuration=b'', codec_configuration=b'',
) )
) )
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established()) cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec( def on_config_codec(
self, self,
@@ -991,11 +1015,17 @@ class AseStateMachine(gatt.Characteristic):
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE, AseReasonCode.NONE,
) )
self.state = self.State.DISABLING if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.DISABLING: if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE, AseReasonCode.NONE,
@@ -1046,6 +1076,7 @@ class AseStateMachine(gatt.Characteristic):
def state(self, new_state: State) -> None: def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state self._state = new_state
self.emit('state_change')
@property @property
def value(self): def value(self):
@@ -1118,6 +1149,7 @@ class AudioStreamControlService(gatt.TemplateService):
ase_state_machines: Dict[int, AseStateMachine] ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__( def __init__(
self, self,
@@ -1155,7 +1187,16 @@ class AudioStreamControlService(gatt.TemplateService):
else: else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data): def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data) operation = ASE_Operation.from_bytes(data)
responses = [] responses = []
logger.debug(f'*** ASCS Write {operation} ***') logger.debug(f'*** ASCS Write {operation} ***')

View File

@@ -26,7 +26,7 @@ import websockets
from typing import Optional from typing import Optional
import bumble.core import bumble.core
from bumble.device import Device from bumble.device import Device, ScoLink
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
@@ -217,11 +217,11 @@ async def main() -> None:
1: hfp.make_ag_sdp_records(1, channel, configuration) 1: hfp.make_ag_sdp_records(1, channel, configuration)
} }
def on_sco_connection(sco_link): def on_sco_connection(sco_link: ScoLink):
assert ag_protocol assert ag_protocol
on_sco_state_change(ag_protocol.active_codec) on_sco_state_change(ag_protocol.active_codec)
sco_link.on('disconnection', lambda _: on_sco_state_change(0)) sco_link.on('disconnection', lambda _: on_sco_state_change(0))
sco_link.on('pdu', on_sco_packet) sco_link.sink = on_sco_packet
device.on('sco_connection', on_sco_connection) device.on('sco_connection', on_sco_connection)
if len(sys.argv) >= 4: if len(sys.argv) >= 4:

View File

@@ -16,20 +16,28 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import datetime
import functools
import logging import logging
import sys import sys
import os import os
import io
import struct import struct
import secrets import secrets
from typing import Dict
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Device, CisLink from bumble.device import Device
from bumble.hci import ( from bumble.hci import (
CodecID, CodecID,
CodingFormat, CodingFormat,
HCI_IsoDataPacket, HCI_IsoDataPacket,
) )
from bumble.profiles.bap import ( from bumble.profiles.bap import (
AseStateMachine,
UnicastServerAdvertisingData, UnicastServerAdvertisingData,
CodecSpecificConfiguration,
CodecSpecificCapabilities, CodecSpecificCapabilities,
ContextType, ContextType,
AudioLocation, AudioLocation,
@@ -45,6 +53,32 @@ from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
def _sink_pac_record() -> PacRecord:
return PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_8000
| SupportedSamplingFrequency.FREQ_16000
| SupportedSamplingFrequency.FREQ_24000
| SupportedSamplingFrequency.FREQ_32000
| SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_7500_US_SUPPORTED
| SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
file_outputs: Dict[AseStateMachine, io.BufferedWriter] = {}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main() -> None:
if len(sys.argv) < 3: if len(sys.argv) < 3:
@@ -71,49 +105,17 @@ async def main() -> None:
PublishedAudioCapabilitiesService( PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED, supported_source_context=ContextType.PROHIBITED,
available_source_context=ContextType.PROHIBITED, available_source_context=ContextType.PROHIBITED,
supported_sink_context=ContextType.MEDIA, supported_sink_context=ContextType(0xFF), # All context types
available_sink_context=ContextType.MEDIA, available_sink_context=ContextType(0xFF), # All context types
sink_audio_locations=( sink_audio_locations=(
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
), ),
sink_pac=[ sink_pac=[_sink_pac_record()],
# Codec Capability Setting 16_2
PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_16000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_counts=[1],
min_octets_per_codec_frame=40,
max_octets_per_codec_frame=40,
supported_max_codec_frames_per_sdu=1,
),
),
# Codec Capability Setting 24_2
PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_counts=[1],
min_octets_per_codec_frame=120,
max_octets_per_codec_frame=120,
supported_max_codec_frames_per_sdu=1,
),
),
],
) )
) )
device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) ascs = AudioStreamControlService(device, sink_ase_id=[1], source_ase_id=[2])
device.add_service(ascs)
advertising_data = ( advertising_data = (
bytes( bytes(
@@ -143,44 +145,57 @@ async def main() -> None:
+ csis.get_advertising_data() + csis.get_advertising_data()
+ bytes(UnicastServerAdvertisingData()) + bytes(UnicastServerAdvertisingData())
) )
subprocess = await asyncio.create_subprocess_shell(
f'dlc3 | ffplay pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdin = subprocess.stdin def on_pdu(ase: AseStateMachine, pdu: HCI_IsoDataPacket):
assert stdin
# Write a fake LC3 header to dlc3.
stdin.write(
bytes([0x1C, 0xCC]) # Header.
+ struct.pack(
'<HHHHHHI',
18, # Header length.
48000 // 100, # Sampling Rate(/100Hz).
0, # Bitrate(unused).
1, # Channels.
10000 // 10, # Frame duration(/10us).
0, # RFU.
0x0FFFFFFF, # Frame counts.
)
)
def on_pdu(pdu: HCI_IsoDataPacket):
# LC3 format: |frame_length(2)| + |frame(length)|. # LC3 format: |frame_length(2)| + |frame(length)|.
sdu = b''
if pdu.iso_sdu_length: if pdu.iso_sdu_length:
stdin.write(struct.pack('<H', pdu.iso_sdu_length)) sdu = struct.pack('<H', pdu.iso_sdu_length)
stdin.write(pdu.iso_sdu_fragment) sdu += pdu.iso_sdu_fragment
file_outputs[ase].write(sdu)
def on_cis(cis_link: CisLink): def on_ase_state_change(
cis_link.on('pdu', on_pdu) state: AseStateMachine.State,
ase: AseStateMachine,
) -> None:
if state != AseStateMachine.State.STREAMING:
if file_output := file_outputs.pop(ase):
file_output.close()
else:
file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb')
codec_configuration = ase.codec_specific_configuration
assert isinstance(codec_configuration, CodecSpecificConfiguration)
# Write a LC3 header.
file_output.write(
bytes([0x1C, 0xCC]) # Header.
+ struct.pack(
'<HHHHHHI',
18, # Header length.
codec_configuration.sampling_frequency.hz
// 100, # Sampling Rate(/100Hz).
0, # Bitrate(unused).
bin(codec_configuration.audio_channel_allocation).count(
'1'
), # Channels.
codec_configuration.frame_duration.us
// 10, # Frame duration(/10us).
0, # RFU.
0x0FFFFFFF, # Frame counts.
)
)
file_outputs[ase] = file_output
assert ase.cis_link
ase.cis_link.sink = functools.partial(on_pdu, ase)
device.once('cis_establishment', on_cis) for ase in ascs.ase_state_machines.values():
ase.on(
'state_change',
functools.partial(on_ase_state_change, ase=ase),
)
await device.create_advertising_set( await device.create_advertising_set(
advertising_data=advertising_data, advertising_data=advertising_data,
auto_restart=True,
) )
await hci_transport.source.terminated await hci_transport.source.terminated

View File

@@ -102,7 +102,7 @@ async def main() -> None:
supported_frame_durations=( supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED SupportedFrameDuration.DURATION_10000_US_SUPPORTED
), ),
supported_audio_channel_counts=[1], supported_audio_channel_count=[1],
min_octets_per_codec_frame=120, min_octets_per_codec_frame=120,
max_octets_per_codec_frame=120, max_octets_per_codec_frame=120,
supported_max_codec_frames_per_sdu=1, supported_max_codec_frames_per_sdu=1,

View File

@@ -96,6 +96,7 @@ development =
types-appdirs >= 1.4.3 types-appdirs >= 1.4.3
types-invoke >= 1.7.3 types-invoke >= 1.7.3
types-protobuf >= 4.21.0 types-protobuf >= 4.21.0
wasmtime == 20.0.0
avatar = avatar =
pandora-avatar == 0.0.9 pandora-avatar == 0.0.9
rootcanal == 1.10.0 ; python_version>='3.10' rootcanal == 1.10.0 ; python_version>='3.10'

View File

@@ -72,7 +72,7 @@ def test_codec_specific_capabilities() -> None:
cap = CodecSpecificCapabilities( cap = CodecSpecificCapabilities(
supported_sampling_frequencies=SAMPLE_FREQUENCY, supported_sampling_frequencies=SAMPLE_FREQUENCY,
supported_frame_durations=FRAME_SURATION, supported_frame_durations=FRAME_SURATION,
supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, supported_audio_channel_count=AUDIO_CHANNEL_COUNTS,
min_octets_per_codec_frame=40, min_octets_per_codec_frame=40,
max_octets_per_codec_frame=40, max_octets_per_codec_frame=40,
supported_max_codec_frames_per_sdu=1, supported_max_codec_frames_per_sdu=1,
@@ -88,7 +88,7 @@ def test_pac_record() -> None:
cap = CodecSpecificCapabilities( cap = CodecSpecificCapabilities(
supported_sampling_frequencies=SAMPLE_FREQUENCY, supported_sampling_frequencies=SAMPLE_FREQUENCY,
supported_frame_durations=FRAME_SURATION, supported_frame_durations=FRAME_SURATION,
supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, supported_audio_channel_count=AUDIO_CHANNEL_COUNTS,
min_octets_per_codec_frame=40, min_octets_per_codec_frame=40,
max_octets_per_codec_frame=40, max_octets_per_codec_frame=40,
supported_max_codec_frames_per_sdu=1, supported_max_codec_frames_per_sdu=1,
@@ -216,7 +216,7 @@ async def test_pacs():
supported_frame_durations=( supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED SupportedFrameDuration.DURATION_10000_US_SUPPORTED
), ),
supported_audio_channel_counts=[1], supported_audio_channel_count=[1],
min_octets_per_codec_frame=40, min_octets_per_codec_frame=40,
max_octets_per_codec_frame=40, max_octets_per_codec_frame=40,
supported_max_codec_frames_per_sdu=1, supported_max_codec_frames_per_sdu=1,
@@ -232,7 +232,7 @@ async def test_pacs():
supported_frame_durations=( supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED SupportedFrameDuration.DURATION_10000_US_SUPPORTED
), ),
supported_audio_channel_counts=[1], supported_audio_channel_count=[1],
min_octets_per_codec_frame=60, min_octets_per_codec_frame=60,
max_octets_per_codec_frame=60, max_octets_per_codec_frame=60,
supported_max_codec_frames_per_sdu=1, supported_max_codec_frames_per_sdu=1,

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import functools
import logging import logging
import os import os
from types import LambdaType from types import LambdaType
@@ -35,12 +36,14 @@ from bumble.hci import (
HCI_COMMAND_STATUS_PENDING, HCI_COMMAND_STATUS_PENDING,
HCI_CREATE_CONNECTION_COMMAND, HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
Address, Address,
OwnAddressType, OwnAddressType,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_Connection_Complete_Event, HCI_Connection_Complete_Event,
HCI_Connection_Request_Event, HCI_Connection_Request_Event,
HCI_Error,
HCI_Packet, HCI_Packet,
) )
from bumble.gatt import ( from bumble.gatt import (
@@ -52,6 +55,10 @@ from bumble.gatt import (
from .test_utils import TwoDevices, async_barrier from .test_utils import TwoDevices, async_barrier
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
_TIMEOUT = 0.1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -391,6 +398,29 @@ async def test_get_remote_le_features():
assert (await devices.connections[0].get_remote_le_features()) is not None assert (await devices.connections[0].get_remote_le_features()) is not None
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features_failed():
devices = TwoDevices()
await devices.setup_connection()
def on_hci_le_read_remote_features_complete_event(event):
devices[0].host.emit(
'le_remote_features_failure',
event.connection_handle,
HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
)
devices[0].host.on_hci_le_read_remote_features_complete_event = (
on_hci_le_read_remote_features_complete_event
)
with pytest.raises(HCI_Error):
await asyncio.wait_for(
devices.connections[0].get_remote_le_features(), _TIMEOUT
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cis(): async def test_cis():
@@ -439,6 +469,65 @@ async def test_cis():
await cis_links[1].disconnect() await cis_links[1].disconnect()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_setup_failure():
devices = TwoDevices()
await devices.setup_connection()
cis_requests = asyncio.Queue()
def on_cis_request(
acl_connection: Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
):
del acl_connection, cig_id, cis_id
cis_requests.put_nowait(cis_handle)
devices[1].on('cis_request', on_cis_request)
cis_handles = await devices[0].setup_cig(
cig_id=1,
cis_id=[2],
sdu_interval=(0, 0),
framing=0,
max_sdu=(0, 0),
retransmission_number=0,
max_transport_latency=(0, 0),
)
assert len(cis_handles) == 1
cis_create_task = asyncio.create_task(
devices[0].create_cis(
[
(cis_handles[0], devices.connections[0].handle),
]
)
)
def on_hci_le_cis_established_event(host, event):
host.emit(
'cis_establishment_failure',
event.connection_handle,
HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
)
for device in devices:
device.host.on_hci_le_cis_established_event = functools.partial(
on_hci_le_cis_established_event, device.host
)
cis_request = await asyncio.wait_for(cis_requests.get(), _TIMEOUT)
with pytest.raises(HCI_Error):
await asyncio.wait_for(devices[1].accept_cis_request(cis_request), _TIMEOUT)
with pytest.raises(HCI_Error):
await asyncio.wait_for(cis_create_task, _TIMEOUT)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_gatt_services_with_gas(): def test_gatt_services_with_gas():
device = Device(host=Host(None, None)) device = Device(host=Host(None, None))