Unify ISO methods

This commit is contained in:
Josh Wu
2024-12-23 18:18:32 +08:00
parent c3b2bb19d5
commit 27fcd43224
4 changed files with 140 additions and 107 deletions

View File

@@ -37,9 +37,7 @@ from typing import (
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
@@ -1145,7 +1143,7 @@ class Peer:
connection.gatt_client = self.gatt_client
@property
def services(self) -> List[gatt_client.ServiceProxy]:
def services(self) -> list[gatt_client.ServiceProxy]:
return self.gatt_client.services
async def request_mtu(self, mtu: int) -> int:
@@ -1155,24 +1153,24 @@ class Peer:
async def discover_service(
self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]:
) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid)
async def discover_services(
self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]:
) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids)
async def discover_included_services(
self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]:
) -> list[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service)
async def discover_characteristics(
self,
uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]:
) -> list[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service
)
@@ -1187,7 +1185,7 @@ class Peer:
characteristic, start_handle, end_handle
)
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
async def discover_attributes(self) -> list[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes()
async def discover_all(self):
@@ -1231,17 +1229,17 @@ class Peer:
async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]:
) -> list[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
def get_services_by_uuid(self, uuid: core.UUID) -> list[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid(
self,
uuid: core.UUID,
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
) -> List[gatt_client.CharacteristicProxy]:
) -> list[gatt_client.CharacteristicProxy]:
if isinstance(service, core.UUID):
return list(
itertools.chain(
@@ -1328,41 +1326,75 @@ class ScoLink(CompositeEventEmitter):
# -----------------------------------------------------------------------------
@dataclass
class CisLink(CompositeEventEmitter):
class State(IntEnum):
PENDING = 0
ESTABLISHED = 1
device: Device
acl_connection: Connection # Based ACL connection
handle: int # CIS handle assigned by Controller (in LE_Set_CIG_Parameters Complete or LE_CIS_Request events)
cis_id: int # CIS ID assigned by Central device
cig_id: int # CIG ID assigned by Central device
state: State = State.PENDING
sink: Optional[Callable[[hci.HCI_IsoDataPacket], Any]] = None
def __post_init__(self) -> None:
super().__init__()
async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
# -----------------------------------------------------------------------------
@dataclass
class BisLink:
class _IsoLink:
handle: int
big: Big | BigSync
device: Device
packet_sequence_number: int
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
self.device = self.big.device
self.packet_sequence_number = 0
class Direction(IntEnum):
HOST_TO_CONTROLLER = (
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
)
CONTROLLER_TO_HOST = (
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
)
async def setup_data_path(
self,
direction: _IsoLink.Direction,
data_path_id: int = 0,
codec_id: hci.CodingFormat | None = None,
controller_delay: int = 0,
codec_configuration: bytes = b'',
) -> None:
"""Create a data path between controller and given entry.
Args:
direction: Direction of data path.
data_path_id: ID of data path. Default is 0 (HCI).
codec_id: Codec ID. Default is Transparent.
controller_delay: Controller delay in microseconds. Default is 0.
codec_configuration: Codec-specific configuration.
Raises:
HCI_Error: When command complete status is not HCI_SUCCESS.
"""
await self.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
data_path_id=data_path_id,
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=controller_delay,
codec_configuration=codec_configuration,
),
check_result=True,
)
async def remove_data_path(self, direction: _IsoLink.Direction) -> int:
"""Remove a data path with controller on given direction.
Args:
direction: Direction of data path.
Returns:
Command status.
"""
response = await self.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
),
check_result=False,
)
return response.return_parameters.status
def write(self, sdu: bytes) -> None:
"""Write an ISO SDU.
This will automatically increase the packet sequence number.
"""
self.device.host.send_hci_packet(
hci.HCI_IsoDataPacket(
connection_handle=self.handle,
@@ -1377,6 +1409,43 @@ class BisLink:
self.packet_sequence_number += 1
# -----------------------------------------------------------------------------
@dataclass
class CisLink(CompositeEventEmitter, _IsoLink):
class State(IntEnum):
PENDING = 0
ESTABLISHED = 1
device: Device
acl_connection: Connection # Based ACL connection
handle: int # CIS handle assigned by Controller (in LE_Set_CIG_Parameters Complete or LE_CIS_Request events)
cis_id: int # CIS ID assigned by Central device
cig_id: int # CIG ID assigned by Central device
state: State = State.PENDING
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
super().__init__()
self.packet_sequence_number = 0
async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
# -----------------------------------------------------------------------------
@dataclass
class BisLink(_IsoLink):
handle: int
big: Big | BigSync
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
self.device = self.big.device
self.packet_sequence_number = 0
# -----------------------------------------------------------------------------
class Connection(CompositeEventEmitter):
device: Device
@@ -1689,7 +1758,7 @@ class DeviceConfiguration:
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
self.gatt_services: list[Dict[str, Any]] = []
def load_from_dict(self, config: Dict[str, Any]) -> None:
config = copy.deepcopy(config)
@@ -1836,7 +1905,7 @@ def host_event_handler(function):
# List of host event handlers for the Device class.
# (we define this list outside the class, because referencing a class in method
# decorators is not straightforward)
device_host_event_handlers: List[str] = []
device_host_event_handlers: list[str] = []
# -----------------------------------------------------------------------------
@@ -1857,10 +1926,10 @@ class Device(CompositeEventEmitter):
pending_connections: Dict[hci.Address, Connection]
classic_pending_accepts: Dict[
hci.Address,
List[asyncio.Future[Union[Connection, Tuple[hci.Address, int, int]]]],
list[asyncio.Future[Union[Connection, tuple[hci.Address, int, int]]]],
]
advertisement_accumulators: Dict[hci.Address, AdvertisementDataAccumulator]
periodic_advertising_syncs: List[PeriodicAdvertisingSync]
periodic_advertising_syncs: list[PeriodicAdvertisingSync]
config: DeviceConfiguration
legacy_advertiser: Optional[LegacyAdvertiser]
sco_links: Dict[int, ScoLink]
@@ -1868,7 +1937,7 @@ class Device(CompositeEventEmitter):
bigs = dict[int, Big]()
bis_links = dict[int, BisLink]()
big_syncs = dict[int, BigSync]()
_pending_cis: Dict[int, Tuple[int, int]]
_pending_cis: Dict[int, tuple[int, int]]
@composite_listener
class Listener:
@@ -2793,7 +2862,7 @@ class Device(CompositeEventEmitter):
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type: int = hci.OwnAddressType.RANDOM,
filter_duplicates: bool = False,
scanning_phys: List[int] = [hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY],
scanning_phys: Sequence[int] = (hci.HCI_LE_1M_PHY, hci.HCI_LE_CODED_PHY),
) -> None:
# Check that the arguments are legal
if scan_interval < scan_window:
@@ -4124,13 +4193,13 @@ class Device(CompositeEventEmitter):
async def setup_cig(
self,
cig_id: int,
cis_id: List[int],
sdu_interval: Tuple[int, int],
cis_id: Sequence[int],
sdu_interval: tuple[int, int],
framing: int,
max_sdu: Tuple[int, int],
max_sdu: tuple[int, int],
retransmission_number: int,
max_transport_latency: Tuple[int, int],
) -> List[int]:
max_transport_latency: tuple[int, int],
) -> list[int]:
"""Sends hci.HCI_LE_Set_CIG_Parameters_Command.
Args:
@@ -4179,7 +4248,9 @@ class Device(CompositeEventEmitter):
# [LE only]
@experimental('Only for testing.')
async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink]:
async def create_cis(
self, cis_acl_pairs: Sequence[tuple[int, int]]
) -> list[CisLink]:
for cis_handle, acl_handle in cis_acl_pairs:
acl_connection = self.lookup_connection(acl_handle)
assert acl_connection
@@ -4504,7 +4575,7 @@ class Device(CompositeEventEmitter):
self,
status: int,
big_handle: int,
bis_handles: List[int],
bis_handles: list[int],
big_sync_delay: int,
transport_latency_big: int,
phy: int,