forked from auracaster/bumble_mirror
ASCS: Add Source ASE operations
This commit is contained in:
115
bumble/device.py
115
bumble/device.py
@@ -23,7 +23,13 @@ import json
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
from contextlib import asynccontextmanager, AsyncExitStack, closing
|
||||
import sys
|
||||
from contextlib import (
|
||||
asynccontextmanager,
|
||||
AsyncExitStack,
|
||||
closing,
|
||||
AbstractAsyncContextManager,
|
||||
)
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Iterable
|
||||
from typing import (
|
||||
@@ -961,8 +967,9 @@ class ScoLink(CompositeEventEmitter):
|
||||
acl_connection: Connection
|
||||
handle: int
|
||||
link_type: int
|
||||
sink: Optional[Callable[[HCI_SynchronousDataPacket], Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def disconnect(
|
||||
@@ -984,8 +991,9 @@ class CisLink(CompositeEventEmitter):
|
||||
cis_id: int # CIS ID assigned by Central device
|
||||
cig_id: int # CIG ID assigned by Central device
|
||||
state: State = State.PENDING
|
||||
sink: Optional[Callable[[HCI_IsoDataPacket], Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def disconnect(
|
||||
@@ -1533,6 +1541,12 @@ class Device(CompositeEventEmitter):
|
||||
Address.ANY: []
|
||||
} # Futures, by BD address OR [Futures] for Address.ANY
|
||||
|
||||
# In Python <= 3.9 + Rust Runtime, asyncio.Lock cannot be properly initiated.
|
||||
if sys.version_info >= (3, 10):
|
||||
self._cis_lock = asyncio.Lock()
|
||||
else:
|
||||
self._cis_lock = AsyncExitStack()
|
||||
|
||||
# Own address type cache
|
||||
self.connect_own_address_type = None
|
||||
|
||||
@@ -3406,49 +3420,71 @@ class Device(CompositeEventEmitter):
|
||||
for cis_handle, _ in cis_acl_pairs
|
||||
}
|
||||
|
||||
@watcher.on(self, 'cis_establishment')
|
||||
def on_cis_establishment(cis_link: CisLink) -> None:
|
||||
if pending_future := pending_cis_establishments.get(cis_link.handle):
|
||||
pending_future.set_result(cis_link)
|
||||
|
||||
result = await self.send_command(
|
||||
def on_cis_establishment_failure(cis_handle: int, status: int) -> None:
|
||||
if pending_future := pending_cis_establishments.get(cis_handle):
|
||||
pending_future.set_exception(HCI_Error(status))
|
||||
|
||||
watcher.on(self, 'cis_establishment', on_cis_establishment)
|
||||
watcher.on(self, 'cis_establishment_failure', on_cis_establishment_failure)
|
||||
await self.send_command(
|
||||
HCI_LE_Create_CIS_Command(
|
||||
cis_connection_handle=[p[0] for p in cis_acl_pairs],
|
||||
acl_connection_handle=[p[1] for p in cis_acl_pairs],
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
'HCI_LE_Create_CIS_Command failed: '
|
||||
f'{HCI_Constant.error_name(result.status)}'
|
||||
)
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
return await asyncio.gather(*pending_cis_establishments.values())
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
async def accept_cis_request(self, handle: int) -> CisLink:
|
||||
result = await self.send_command(
|
||||
HCI_LE_Accept_CIS_Request_Command(connection_handle=handle),
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
'HCI_LE_Accept_CIS_Request_Command failed: '
|
||||
f'{HCI_Constant.error_name(result.status)}'
|
||||
)
|
||||
raise HCI_StatusError(result)
|
||||
"""[LE Only] Accepts an incoming CIS request.
|
||||
|
||||
pending_cis_establishment = asyncio.get_running_loop().create_future()
|
||||
When the specified CIS handle is already created, this method returns the
|
||||
existed CIS link object immediately.
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
Args:
|
||||
handle: CIS handle to accept.
|
||||
|
||||
@watcher.on(self, 'cis_establishment')
|
||||
def on_cis_establishment(cis_link: CisLink) -> None:
|
||||
if cis_link.handle == handle:
|
||||
pending_cis_establishment.set_result(cis_link)
|
||||
Returns:
|
||||
CIS link object on the given handle.
|
||||
"""
|
||||
if not (cis_link := self.cis_links.get(handle)):
|
||||
raise InvalidStateError(f'No pending CIS request of handle {handle}')
|
||||
|
||||
return await pending_cis_establishment
|
||||
# There might be multiple ASE sharing a CIS channel.
|
||||
# If one of them has accepted the request, the others should just leverage it.
|
||||
async with self._cis_lock:
|
||||
if cis_link.state == CisLink.State.ESTABLISHED:
|
||||
return cis_link
|
||||
|
||||
with closing(EventWatcher()) as watcher:
|
||||
pending_establishment = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_establishment() -> None:
|
||||
pending_establishment.set_result(None)
|
||||
|
||||
def on_establishment_failure(status: int) -> None:
|
||||
pending_establishment.set_exception(HCI_Error(status))
|
||||
|
||||
watcher.on(cis_link, 'establishment', on_establishment)
|
||||
watcher.on(cis_link, 'establishment_failure', on_establishment_failure)
|
||||
|
||||
await self.send_command(
|
||||
HCI_LE_Accept_CIS_Request_Command(connection_handle=handle),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
await pending_establishment
|
||||
return cis_link
|
||||
|
||||
# Mypy believes this is reachable when context is an ExitStack.
|
||||
raise InvalidStateError('Unreachable')
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
@@ -3457,15 +3493,10 @@ class Device(CompositeEventEmitter):
|
||||
handle: int,
|
||||
reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
) -> None:
|
||||
result = await self.send_command(
|
||||
await self.send_command(
|
||||
HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason),
|
||||
check_result=True,
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
'HCI_LE_Reject_CIS_Request_Command failed: '
|
||||
f'{HCI_Constant.error_name(result.status)}'
|
||||
)
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
async def get_remote_le_features(self, connection: Connection) -> LeFeatureMask:
|
||||
"""[LE Only] Reads remote LE supported features.
|
||||
@@ -3485,11 +3516,17 @@ class Device(CompositeEventEmitter):
|
||||
if handle == connection.handle:
|
||||
read_feature_future.set_result(LeFeatureMask(features))
|
||||
|
||||
def on_failure(handle: int, status: int):
|
||||
if handle == connection.handle:
|
||||
read_feature_future.set_exception(HCI_Error(status))
|
||||
|
||||
watcher.on(self.host, 'le_remote_features', on_le_remote_features)
|
||||
watcher.on(self.host, 'le_remote_features_failure', on_failure)
|
||||
await self.send_command(
|
||||
HCI_LE_Read_Remote_Features_Command(
|
||||
connection_handle=connection.handle
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
return await read_feature_future
|
||||
|
||||
@@ -4111,8 +4148,8 @@ class Device(CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
@experimental('Only for testing')
|
||||
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
|
||||
if sco_link := self.sco_links.get(sco_handle):
|
||||
sco_link.emit('pdu', packet)
|
||||
if (sco_link := self.sco_links.get(sco_handle)) and sco_link.sink:
|
||||
sco_link.sink(packet)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@@ -4168,15 +4205,15 @@ class Device(CompositeEventEmitter):
|
||||
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
|
||||
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
|
||||
if cis_link := self.cis_links.pop(cis_handle):
|
||||
cis_link.emit('establishment_failure')
|
||||
cis_link.emit('establishment_failure', status)
|
||||
self.emit('cis_establishment_failure', cis_handle, status)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@experimental('Only for testing')
|
||||
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
|
||||
if cis_link := self.cis_links.get(handle):
|
||||
cis_link.emit('pdu', packet)
|
||||
if (cis_link := self.cis_links.get(handle)) and cis_link.sink:
|
||||
cis_link.sink(packet)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -23,7 +23,7 @@ import functools
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, ClassVar
|
||||
|
||||
from bumble import crypto
|
||||
from .colors import color
|
||||
@@ -2003,7 +2003,7 @@ class HCI_Packet:
|
||||
Abstract Base class for HCI packets
|
||||
'''
|
||||
|
||||
hci_packet_type: int
|
||||
hci_packet_type: ClassVar[int]
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet: bytes) -> HCI_Packet:
|
||||
@@ -6192,12 +6192,23 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class HCI_IsoDataPacket(HCI_Packet):
|
||||
'''
|
||||
See Bluetooth spec @ 5.4.5 HCI ISO Data Packets
|
||||
'''
|
||||
|
||||
hci_packet_type = HCI_ISO_DATA_PACKET
|
||||
hci_packet_type: ClassVar[int] = HCI_ISO_DATA_PACKET
|
||||
|
||||
connection_handle: int
|
||||
data_total_length: int
|
||||
iso_sdu_fragment: bytes
|
||||
pb_flag: int
|
||||
ts_flag: int = 0
|
||||
time_stamp: Optional[int] = None
|
||||
packet_sequence_number: Optional[int] = None
|
||||
iso_sdu_length: Optional[int] = None
|
||||
packet_status_flag: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
|
||||
@@ -6241,28 +6252,6 @@ class HCI_IsoDataPacket(HCI_Packet):
|
||||
iso_sdu_fragment=iso_sdu_fragment,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_handle: int,
|
||||
pb_flag: int,
|
||||
ts_flag: int,
|
||||
data_total_length: int,
|
||||
time_stamp: Optional[int],
|
||||
packet_sequence_number: Optional[int],
|
||||
iso_sdu_length: Optional[int],
|
||||
packet_status_flag: Optional[int],
|
||||
iso_sdu_fragment: bytes,
|
||||
) -> None:
|
||||
self.connection_handle = connection_handle
|
||||
self.pb_flag = pb_flag
|
||||
self.ts_flag = ts_flag
|
||||
self.data_total_length = data_total_length
|
||||
self.time_stamp = time_stamp
|
||||
self.packet_sequence_number = packet_sequence_number
|
||||
self.iso_sdu_length = iso_sdu_length
|
||||
self.packet_status_flag = packet_status_flag
|
||||
self.iso_sdu_fragment = iso_sdu_fragment
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
|
||||
@@ -721,14 +721,16 @@ class Host(AbortableEventEmitter):
|
||||
for connection_handle, num_completed_packets in zip(
|
||||
event.connection_handles, event.num_completed_packets
|
||||
):
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
if connection := self.connections.get(connection_handle):
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
elif not (
|
||||
self.cis_links.get(connection_handle)
|
||||
or self.sco_links.get(connection_handle)
|
||||
):
|
||||
logger.warning(
|
||||
'received packet completion event for unknown handle '
|
||||
f'0x{connection_handle:04X}'
|
||||
)
|
||||
continue
|
||||
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
|
||||
# Classic only
|
||||
def on_hci_connection_request_event(self, event):
|
||||
|
||||
@@ -78,6 +78,10 @@ class AudioLocation(enum.IntFlag):
|
||||
LEFT_SURROUND = 0x04000000
|
||||
RIGHT_SURROUND = 0x08000000
|
||||
|
||||
@property
|
||||
def channel_count(self) -> int:
|
||||
return bin(self.value).count('1')
|
||||
|
||||
|
||||
class AudioInputType(enum.IntEnum):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type'''
|
||||
@@ -218,6 +222,13 @@ class FrameDuration(enum.IntEnum):
|
||||
DURATION_7500_US = 0x00
|
||||
DURATION_10000_US = 0x01
|
||||
|
||||
@property
|
||||
def us(self) -> int:
|
||||
return {
|
||||
FrameDuration.DURATION_7500_US: 7500,
|
||||
FrameDuration.DURATION_10000_US: 10000,
|
||||
}[self]
|
||||
|
||||
|
||||
class SupportedFrameDuration(enum.IntFlag):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration'''
|
||||
@@ -534,7 +545,7 @@ class CodecSpecificCapabilities:
|
||||
|
||||
supported_sampling_frequencies: SupportedSamplingFrequency
|
||||
supported_frame_durations: SupportedFrameDuration
|
||||
supported_audio_channel_counts: Sequence[int]
|
||||
supported_audio_channel_count: Sequence[int]
|
||||
min_octets_per_codec_frame: int
|
||||
max_octets_per_codec_frame: int
|
||||
supported_max_codec_frames_per_sdu: int
|
||||
@@ -543,7 +554,7 @@ class CodecSpecificCapabilities:
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities:
|
||||
offset = 0
|
||||
# Allowed default values.
|
||||
supported_audio_channel_counts = [1]
|
||||
supported_audio_channel_count = [1]
|
||||
supported_max_codec_frames_per_sdu = 1
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
@@ -556,7 +567,7 @@ class CodecSpecificCapabilities:
|
||||
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
|
||||
supported_frame_durations = SupportedFrameDuration(value)
|
||||
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
||||
supported_audio_channel_counts = bits_to_channel_counts(value)
|
||||
supported_audio_channel_count = bits_to_channel_counts(value)
|
||||
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
||||
min_octets_per_sample = value & 0xFFFF
|
||||
max_octets_per_sample = value >> 16
|
||||
@@ -567,7 +578,7 @@ class CodecSpecificCapabilities:
|
||||
return CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=supported_sampling_frequencies,
|
||||
supported_frame_durations=supported_frame_durations,
|
||||
supported_audio_channel_counts=supported_audio_channel_counts,
|
||||
supported_audio_channel_count=supported_audio_channel_count,
|
||||
min_octets_per_codec_frame=min_octets_per_sample,
|
||||
max_octets_per_codec_frame=max_octets_per_sample,
|
||||
supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu,
|
||||
@@ -584,7 +595,7 @@ class CodecSpecificCapabilities:
|
||||
self.supported_frame_durations,
|
||||
2,
|
||||
CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT,
|
||||
channel_counts_to_bits(self.supported_audio_channel_counts),
|
||||
channel_counts_to_bits(self.supported_audio_channel_count),
|
||||
5,
|
||||
CodecSpecificCapabilities.Type.OCTETS_PER_FRAME,
|
||||
self.min_octets_per_codec_frame,
|
||||
@@ -870,15 +881,22 @@ class AseStateMachine(gatt.Characteristic):
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
if cis_id == self.cis_id and self.state == self.State.ENABLING:
|
||||
if (
|
||||
cig_id == self.cig_id
|
||||
and cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
acl_connection.abort_on(
|
||||
'flush', self.service.device.accept_cis_request(cis_handle)
|
||||
)
|
||||
|
||||
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
||||
if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING:
|
||||
self.state = self.State.STREAMING
|
||||
self.cis_link = cis_link
|
||||
if (
|
||||
cis_link.cig_id == self.cig_id
|
||||
and cis_link.cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||
|
||||
async def post_cis_established():
|
||||
await self.service.device.send_command(
|
||||
@@ -891,9 +909,15 @@ class AseStateMachine(gatt.Characteristic):
|
||||
codec_configuration=b'',
|
||||
)
|
||||
)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.STREAMING
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
||||
self.cis_link = cis_link
|
||||
|
||||
def on_cis_disconnection(self, _reason) -> None:
|
||||
self.cis_link = None
|
||||
|
||||
def on_config_codec(
|
||||
self,
|
||||
@@ -991,11 +1015,17 @@ class AseStateMachine(gatt.Characteristic):
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.DISABLING
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
else:
|
||||
self.state = self.State.DISABLING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.DISABLING:
|
||||
if (
|
||||
self.role != AudioRole.SOURCE
|
||||
or self.state != AseStateMachine.State.DISABLING
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
@@ -1046,6 +1076,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
def state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
||||
self._state = new_state
|
||||
self.emit('state_change')
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
@@ -1118,6 +1149,7 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
|
||||
ase_state_machines: Dict[int, AseStateMachine]
|
||||
ase_control_point: gatt.Characteristic
|
||||
_active_client: Optional[device.Connection] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1155,7 +1187,16 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
else:
|
||||
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
|
||||
|
||||
def _on_client_disconnected(self, _reason: int) -> None:
|
||||
for ase in self.ase_state_machines.values():
|
||||
ase.state = AseStateMachine.State.IDLE
|
||||
self._active_client = None
|
||||
|
||||
def on_write_ase_control_point(self, connection, data):
|
||||
if not self._active_client and connection:
|
||||
self._active_client = connection
|
||||
connection.once('disconnection', self._on_client_disconnected)
|
||||
|
||||
operation = ASE_Operation.from_bytes(data)
|
||||
responses = []
|
||||
logger.debug(f'*** ASCS Write {operation} ***')
|
||||
|
||||
Reference in New Issue
Block a user