Unify ISO methods

This commit is contained in:
Josh Wu
2024-12-23 18:18:32 +08:00
parent c3b2bb19d5
commit 27fcd43224
4 changed files with 140 additions and 107 deletions

View File

@@ -712,16 +712,8 @@ async def run_receive(
sdus = [b''] * num_bis sdus = [b''] * num_bis
bis_link.sink = functools.partial(sink, i) bis_link.sink = functools.partial(sink, i)
await device.send_command( await bis_link.setup_data_path(
hci.HCI_LE_Setup_ISO_Data_Path_Command( direction=bis_link.Direction.CONTROLLER_TO_HOST
connection_handle=bis_link.handle,
data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST,
data_path_id=0,
codec_id=hci.CodingFormat(codec_id=hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
) )
terminated = asyncio.Event() terminated = asyncio.Event()
@@ -834,16 +826,8 @@ async def run_broadcast(
) )
print('Setup ISO Data Path') print('Setup ISO Data Path')
for bis_link in big.bis_links: for bis_link in big.bis_links:
await device.send_command( await bis_link.setup_data_path(
hci.HCI_LE_Setup_ISO_Data_Path_Command( direction=bis_link.Direction.HOST_TO_CONTROLLER
connection_handle=bis_link.handle,
data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER,
data_path_id=0,
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
),
check_result=True,
) )
for frame in itertools.cycle(frames): for frame in itertools.cycle(frames):

View File

@@ -39,7 +39,7 @@ import aiohttp.web
import bumble import bumble
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
@@ -110,7 +110,7 @@ async def lc3_source_task(
sdu_length: int, sdu_length: int,
frame_duration_us: int, frame_duration_us: int,
device: Device, device: Device,
cis_handle: int, cis_link: CisLink,
) -> None: ) -> None:
logger.info( logger.info(
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f", "lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
@@ -120,7 +120,6 @@ async def lc3_source_task(
) )
with wave.open(filename, 'rb') as wav: with wave.open(filename, 'rb') as wav:
bits_per_sample = wav.getsampwidth() * 8 bits_per_sample = wav.getsampwidth() * 8
packet_sequence_number = 0
encoder: lc3.Encoder | None = None encoder: lc3.Encoder | None = None
@@ -150,18 +149,8 @@ async def lc3_source_task(
num_bytes=sdu_length, num_bytes=sdu_length,
bit_depth=bits_per_sample, bit_depth=bits_per_sample,
) )
cis_link.write(sdu)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now() sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds() * 0.9) await asyncio.sleep(sleep_time.total_seconds() * 0.9)
@@ -309,6 +298,7 @@ class Speaker:
advertising_interval_min=25, advertising_interval_min=25,
advertising_interval_max=25, advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'), address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
) )
device_config.le_enabled = True device_config.le_enabled = True
@@ -393,7 +383,7 @@ class Speaker:
), ),
frame_duration_us=codec_config.frame_duration.us, frame_duration_us=codec_config.frame_duration.us,
device=self.device, device=self.device,
cis_handle=ase.cis_link.handle, cis_link=ase.cis_link,
), ),
) )
else: else:

View File

@@ -37,9 +37,7 @@ from typing import (
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
List,
Optional, Optional,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@@ -1145,7 +1143,7 @@ class Peer:
connection.gatt_client = self.gatt_client connection.gatt_client = self.gatt_client
@property @property
def services(self) -> List[gatt_client.ServiceProxy]: def services(self) -> list[gatt_client.ServiceProxy]:
return self.gatt_client.services return self.gatt_client.services
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
@@ -1155,24 +1153,24 @@ class Peer:
async def discover_service( async def discover_service(
self, uuid: Union[core.UUID, str] self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]: ) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid) return await self.gatt_client.discover_service(uuid)
async def discover_services( async def discover_services(
self, uuids: Iterable[core.UUID] = () self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]: ) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids) return await self.gatt_client.discover_services(uuids)
async def discover_included_services( async def discover_included_services(
self, service: gatt_client.ServiceProxy self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]: ) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service) return await self.gatt_client.discover_included_services(service)
async def discover_characteristics( async def discover_characteristics(
self, self,
uuids: Iterable[Union[core.UUID, str]] = (), uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None, service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]: ) -> list[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics( return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service uuids=uuids, service=service
) )
@@ -1187,7 +1185,7 @@ class Peer:
characteristic, start_handle, end_handle characteristic, start_handle, end_handle
) )
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]: async def discover_attributes(self) -> list[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes() return await self.gatt_client.discover_attributes()
async def discover_all(self): async def discover_all(self):
@@ -1231,17 +1229,17 @@ class Peer:
async def read_characteristics_by_uuid( async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]: ) -> list[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service) return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]: def get_services_by_uuid(self, uuid: core.UUID) -> list[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid) return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid( def get_characteristics_by_uuid(
self, self,
uuid: core.UUID, uuid: core.UUID,
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None, service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
) -> List[gatt_client.CharacteristicProxy]: ) -> list[gatt_client.CharacteristicProxy]:
if isinstance(service, core.UUID): if isinstance(service, core.UUID):
return list( return list(
itertools.chain( itertools.chain(
@@ -1328,41 +1326,75 @@ class ScoLink(CompositeEventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclass class _IsoLink:
class CisLink(CompositeEventEmitter):
class State(IntEnum):
PENDING = 0
ESTABLISHED = 1
device: Device
acl_connection: Connection # Based ACL connection
handle: int # CIS handle assigned by Controller (in LE_Set_CIG_Parameters Complete or LE_CIS_Request events)
cis_id: int # CIS ID assigned by Central device
cig_id: int # CIG ID assigned by Central device
state: State = State.PENDING
sink: Optional[Callable[[hci.HCI_IsoDataPacket], Any]] = None
def __post_init__(self) -> None:
super().__init__()
async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
# -----------------------------------------------------------------------------
@dataclass
class BisLink:
handle: int handle: int
big: Big | BigSync device: Device
packet_sequence_number: int
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None: class Direction(IntEnum):
self.device = self.big.device HOST_TO_CONTROLLER = (
self.packet_sequence_number = 0 hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
)
CONTROLLER_TO_HOST = (
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
)
async def setup_data_path(
self,
direction: _IsoLink.Direction,
data_path_id: int = 0,
codec_id: hci.CodingFormat | None = None,
controller_delay: int = 0,
codec_configuration: bytes = b'',
) -> None:
"""Create a data path between controller and given entry.
Args:
direction: Direction of data path.
data_path_id: ID of data path. Default is 0 (HCI).
codec_id: Codec ID. Default is Transparent.
controller_delay: Controller delay in microseconds. Default is 0.
codec_configuration: Codec-specific configuration.
Raises:
HCI_Error: When command complete status is not HCI_SUCCESS.
"""
await self.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
data_path_id=data_path_id,
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=controller_delay,
codec_configuration=codec_configuration,
),
check_result=True,
)
async def remove_data_path(self, direction: _IsoLink.Direction) -> int:
"""Remove a data path with controller on given direction.
Args:
direction: Direction of data path.
Returns:
Command status.
"""
response = await self.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
),
check_result=False,
)
return response.return_parameters.status
def write(self, sdu: bytes) -> None: def write(self, sdu: bytes) -> None:
"""Write an ISO SDU.
This will automatically increase the packet sequence number.
"""
self.device.host.send_hci_packet( self.device.host.send_hci_packet(
hci.HCI_IsoDataPacket( hci.HCI_IsoDataPacket(
connection_handle=self.handle, connection_handle=self.handle,
@@ -1377,6 +1409,43 @@ class BisLink:
self.packet_sequence_number += 1 self.packet_sequence_number += 1
# -----------------------------------------------------------------------------
@dataclass
class CisLink(CompositeEventEmitter, _IsoLink):
class State(IntEnum):
PENDING = 0
ESTABLISHED = 1
device: Device
acl_connection: Connection # Based ACL connection
handle: int # CIS handle assigned by Controller (in LE_Set_CIG_Parameters Complete or LE_CIS_Request events)
cis_id: int # CIS ID assigned by Central device
cig_id: int # CIG ID assigned by Central device
state: State = State.PENDING
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
super().__init__()
self.packet_sequence_number = 0
async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
# -----------------------------------------------------------------------------
@dataclass
class BisLink(_IsoLink):
handle: int
big: Big | BigSync
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
self.device = self.big.device
self.packet_sequence_number = 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection(CompositeEventEmitter): class Connection(CompositeEventEmitter):
device: Device device: Device
@@ -1689,7 +1758,7 @@ class DeviceConfiguration:
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = [] self.gatt_services: list[Dict[str, Any]] = []
def load_from_dict(self, config: Dict[str, Any]) -> None: def load_from_dict(self, config: Dict[str, Any]) -> None:
config = copy.deepcopy(config) config = copy.deepcopy(config)
@@ -1836,7 +1905,7 @@ def host_event_handler(function):
# List of host event handlers for the Device class. # List of host event handlers for the Device class.
# (we define this list outside the class, because referencing a class in method # (we define this list outside the class, because referencing a class in method
# decorators is not straightforward) # decorators is not straightforward)
device_host_event_handlers: List[str] = [] device_host_event_handlers: list[str] = []
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1857,10 +1926,10 @@ class Device(CompositeEventEmitter):
pending_connections: Dict[hci.Address, Connection] pending_connections: Dict[hci.Address, Connection]
classic_pending_accepts: Dict[ classic_pending_accepts: Dict[
hci.Address, hci.Address,
List[asyncio.Future[Union[Connection, Tuple[hci.Address, int, int]]]], list[asyncio.Future[Union[Connection, tuple[hci.Address, int, int]]]],
] ]
advertisement_accumulators: Dict[hci.Address, AdvertisementDataAccumulator] advertisement_accumulators: Dict[hci.Address, AdvertisementDataAccumulator]
periodic_advertising_syncs: List[PeriodicAdvertisingSync] periodic_advertising_syncs: list[PeriodicAdvertisingSync]
config: DeviceConfiguration config: DeviceConfiguration
legacy_advertiser: Optional[LegacyAdvertiser] legacy_advertiser: Optional[LegacyAdvertiser]
sco_links: Dict[int, ScoLink] sco_links: Dict[int, ScoLink]
@@ -1868,7 +1937,7 @@ class Device(CompositeEventEmitter):
bigs = dict[int, Big]() bigs = dict[int, Big]()
bis_links = dict[int, BisLink]() bis_links = dict[int, BisLink]()
big_syncs = dict[int, BigSync]() big_syncs = dict[int, BigSync]()
_pending_cis: Dict[int, Tuple[int, int]] _pending_cis: Dict[int, tuple[int, int]]
@composite_listener @composite_listener
class Listener: class Listener:
@@ -2793,7 +2862,7 @@ class Device(CompositeEventEmitter):
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type: int = hci.OwnAddressType.RANDOM, own_address_type: int = hci.OwnAddressType.RANDOM,
filter_duplicates: bool = False, filter_duplicates: bool = False,
scanning_phys: List[int] = [hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY], scanning_phys: Sequence[int] = (hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY),
) -> None: ) -> None:
# Check that the arguments are legal # Check that the arguments are legal
if scan_interval < scan_window: if scan_interval < scan_window:
@@ -4124,13 +4193,13 @@ class Device(CompositeEventEmitter):
async def setup_cig( async def setup_cig(
self, self,
cig_id: int, cig_id: int,
cis_id: List[int], cis_id: Sequence[int],
sdu_interval: Tuple[int, int], sdu_interval: tuple[int, int],
framing: int, framing: int,
max_sdu: Tuple[int, int], max_sdu: tuple[int, int],
retransmission_number: int, retransmission_number: int,
max_transport_latency: Tuple[int, int], max_transport_latency: tuple[int, int],
) -> List[int]: ) -> list[int]:
"""Sends hci.HCI_LE_Set_CIG_Parameters_Command. """Sends hci.HCI_LE_Set_CIG_Parameters_Command.
Args: Args:
@@ -4179,7 +4248,9 @@ class Device(CompositeEventEmitter):
# [LE only] # [LE only]
@experimental('Only for testing.') @experimental('Only for testing.')
async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink]: async def create_cis(
self, cis_acl_pairs: Sequence[tuple[int, int]]
) -> list[CisLink]:
for cis_handle, acl_handle in cis_acl_pairs: for cis_handle, acl_handle in cis_acl_pairs:
acl_connection = self.lookup_connection(acl_handle) acl_connection = self.lookup_connection(acl_handle)
assert acl_connection assert acl_connection
@@ -4504,7 +4575,7 @@ class Device(CompositeEventEmitter):
self, self,
status: int, status: int,
big_handle: int, big_handle: int,
bis_handles: List[int], bis_handles: list[int],
big_sync_delay: int, big_sync_delay: int,
transport_latency_big: int, transport_latency_big: int,
phy: int, phy: int,

View File

@@ -17,6 +17,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import logging import logging
import struct import struct
@@ -258,8 +259,8 @@ class AseReasonCode(enum.IntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum): class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -354,16 +355,7 @@ class AseStateMachine(gatt.Characteristic):
cis_link.on('disconnection', self.on_cis_disconnection) cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established(): async def post_cis_established():
await self.service.device.send_command( await cis_link.setup_data_path(direction=self.role)
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK: if self.role == AudioRole.SINK:
self.state = self.State.STREAMING self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
@@ -511,12 +503,8 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.RELEASING self.state = self.State.RELEASING
async def remove_cis_async(): async def remove_cis_async():
await self.service.device.send_command( if self.cis_link:
hci.HCI_LE_Remove_ISO_Data_Path_Command( await self.cis_link.remove_data_path(self.role)
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)