From 58c9c4f590735f4162f117b2b36c56aac36e3178 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Wed, 29 Nov 2023 19:19:40 -0800 Subject: [PATCH 1/8] fix #354 --- bumble/gatt_client.py | 40 ++++++++++++++++------ tests/gatt_test.py | 79 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 0c69b12..d5a8ec7 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -207,11 +207,11 @@ class CharacteristicProxy(AttributeProxy): return await self.client.subscribe(self, subscriber, prefer_notify) - async def unsubscribe(self, subscriber=None): + async def unsubscribe(self, subscriber=None, force=False): if subscriber in self.subscribers: subscriber = self.subscribers.pop(subscriber) - return await self.client.unsubscribe(self, subscriber) + return await self.client.unsubscribe(self, subscriber, force) def __str__(self) -> str: return ( @@ -262,10 +262,8 @@ class Client: self.request_semaphore = asyncio.Semaphore(1) self.pending_request = None self.pending_response = None - self.notification_subscribers = ( - {} - ) # Notification subscribers, by attribute handle - self.indication_subscribers = {} # Indication subscribers, by attribute handle + self.notification_subscribers = {} # Subscriber set, by attribute handle + self.indication_subscribers = {} # Subscriber set, by attribute handle self.services = [] self.cached_values = {} @@ -836,6 +834,7 @@ class Client: subscriber_set = subscribers.setdefault(characteristic.handle, set()) if subscriber is not None: subscriber_set.add(subscriber) + # Add the characteristic as a subscriber, which will result in the # characteristic emitting an 'update' event when a notification or indication # is received @@ -847,7 +846,14 @@ class Client: self, characteristic: CharacteristicProxy, subscriber: Optional[Callable[[bytes], Any]] = None, + force: bool = False, ) -> None: + ''' + Unsubscribe from a characteristic. + + If `force` is True, this will write zeros to the CCCD when there are no + subscribers left, even if there were already no registered subscribers. + ''' # If we haven't already discovered the descriptors for this characteristic, # do it now if not characteristic.descriptors_discovered: @@ -861,25 +867,39 @@ class Client: logger.warning('unsubscribing from characteristic with no CCCD descriptor') return + # Check if the characteristic has subscribers + if not ( + characteristic.handle in self.notification_subscribers + or characteristic.handle in self.indication_subscribers + ): + if not force: + return + + # Remove the subscriber(s) if subscriber is not None: # Remove matching subscriber from subscriber sets for subscriber_set in ( self.notification_subscribers, self.indication_subscribers, ): - subscribers = subscriber_set.get(characteristic.handle, set()) - if subscriber in subscribers: + if ( + subscribers := subscriber_set.get(characteristic.handle) + ) and subscriber in subscribers: subscribers.remove(subscriber) # Cleanup if we removed the last one if not subscribers: del subscriber_set[characteristic.handle] else: - # Remove all subscribers for this attribute from the sets! + # Remove all subscribers for this attribute from the sets self.notification_subscribers.pop(characteristic.handle, None) self.indication_subscribers.pop(characteristic.handle, None) - if not self.notification_subscribers and not self.indication_subscribers: + # Update the CCCD + if not ( + characteristic.handle in self.notification_subscribers + or characteristic.handle in self.indication_subscribers + ): # No more subscribers left await self.write_value(cccd, b'\x00\x00', with_response=True) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index d9f6d60..85b40a9 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -20,6 +20,7 @@ import logging import os import struct import pytest +from unittest.mock import Mock, ANY from bumble.controller import Controller from bumble.gatt_client import CharacteristicProxy @@ -763,6 +764,83 @@ async def test_subscribe_notify(): assert not c3._called_3 +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_unsubscribe(): + [client, server] = LinkedDevices().devices[:2] + + characteristic1 = Characteristic( + 'FDB159DB-036C-49E3-B3DB-6325AC750806', + Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, + Characteristic.READABLE, + bytes([1, 2, 3]), + ) + characteristic2 = Characteristic( + '3234C4F4-3F34-4616-8935-45A50EE05DEB', + Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, + Characteristic.READABLE, + bytes([1, 2, 3]), + ) + + service1 = Service( + '3A657F47-D34F-46B3-B1EC-698E29B6B829', + [characteristic1, characteristic2], + ) + server.add_services([service1]) + + mock1 = Mock() + characteristic1.on('subscription', mock1) + mock2 = Mock() + characteristic2.on('subscription', mock2) + + await client.power_on() + await server.power_on() + connection = await client.connect(server.random_address) + peer = Peer(connection) + + await peer.discover_services() + await peer.discover_characteristics() + c = peer.get_characteristics_by_uuid(characteristic1.uuid) + assert len(c) == 1 + c1 = c[0] + c = peer.get_characteristics_by_uuid(characteristic2.uuid) + assert len(c) == 1 + c2 = c[0] + + await c1.subscribe() + await async_barrier() + mock1.assert_called_once_with(ANY, True, False) + + await c2.subscribe() + await async_barrier() + mock2.assert_called_once_with(ANY, True, False) + + mock1.reset_mock() + await c1.unsubscribe() + await async_barrier() + mock1.assert_called_once_with(ANY, False, False) + + mock2.reset_mock() + await c2.unsubscribe() + await async_barrier() + mock2.assert_called_once_with(ANY, False, False) + + mock1.reset_mock() + await c1.unsubscribe() + await async_barrier() + mock1.assert_not_called() + + mock2.reset_mock() + await c2.unsubscribe() + await async_barrier() + mock2.assert_not_called() + + mock1.reset_mock() + await c1.unsubscribe(force=True) + await async_barrier() + mock1.assert_called_once_with(ANY, False, False) + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_mtu_exchange(): @@ -886,6 +964,7 @@ async def async_main(): await test_read_write() await test_read_write2() await test_subscribe_notify() + await test_unsubscribe() await test_characteristic_encoding() await test_mtu_exchange() From c5def93bb81efeef066784d16493d35d0a0963c3 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 27 Nov 2023 19:22:14 +0800 Subject: [PATCH 2/8] CIS and SCO responder support --- bumble/device.py | 366 +++++++++++++++++++++++++++++++++++--- bumble/hci.py | 6 +- bumble/host.py | 30 +++- examples/leaudio.json | 5 + examples/run_cig_setup.py | 105 +++++++++++ 5 files changed, 483 insertions(+), 29 deletions(-) create mode 100644 examples/leaudio.json create mode 100644 examples/run_cig_setup.py diff --git a/bumble/device.py b/bumble/device.py index 37f1610..75caece 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -21,7 +21,7 @@ import functools import json import asyncio import logging -from contextlib import asynccontextmanager, AsyncExitStack +from contextlib import asynccontextmanager, AsyncExitStack, closing from dataclasses import dataclass from collections.abc import Iterable from typing import ( @@ -49,6 +49,7 @@ from .hci import ( HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, HCI_CENTRAL_ROLE, HCI_COMMAND_STATUS_PENDING, + HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE, HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR, HCI_DISPLAY_YES_NO_IO_CAPABILITY, HCI_DISPLAY_ONLY_IO_CAPABILITY, @@ -85,29 +86,35 @@ from .hci import ( HCI_Constant, HCI_Create_Connection_Cancel_Command, HCI_Create_Connection_Command, + HCI_Create_Connection_Command, HCI_Disconnect_Command, HCI_Encryption_Change_Event, HCI_Error, HCI_IO_Capability_Request_Reply_Command, HCI_Inquiry_Cancel_Command, HCI_Inquiry_Command, + HCI_IsoDataPacket, + HCI_LE_Accept_CIS_Request_Command, HCI_LE_Add_Device_To_Resolving_List_Command, HCI_LE_Advertising_Report_Event, HCI_LE_Clear_Resolving_List_Command, HCI_LE_Connection_Update_Command, HCI_LE_Create_Connection_Cancel_Command, HCI_LE_Create_Connection_Command, + HCI_LE_Create_CIS_Command, HCI_LE_Enable_Encryption_Command, HCI_LE_Extended_Advertising_Report_Event, HCI_LE_Extended_Create_Connection_Command, HCI_LE_Rand_Command, HCI_LE_Read_PHY_Command, + HCI_LE_Reject_CIS_Request_Command, HCI_LE_Remove_Advertising_Set_Command, HCI_LE_Set_Address_Resolution_Enable_Command, HCI_LE_Set_Advertising_Data_Command, HCI_LE_Set_Advertising_Enable_Command, HCI_LE_Set_Advertising_Parameters_Command, HCI_LE_Set_Advertising_Set_Random_Address_Command, + HCI_LE_Set_CIG_Parameters_Command, HCI_LE_Set_Data_Length_Command, HCI_LE_Set_Default_PHY_Command, HCI_LE_Set_Extended_Scan_Enable_Command, @@ -116,6 +123,7 @@ from .hci import ( HCI_LE_Set_Extended_Advertising_Data_Command, HCI_LE_Set_Extended_Advertising_Enable_Command, HCI_LE_Set_Extended_Advertising_Parameters_Command, + HCI_LE_Set_Host_Feature_Command, HCI_LE_Set_PHY_Command, HCI_LE_Set_Random_Address_Command, HCI_LE_Set_Scan_Enable_Command, @@ -130,6 +138,7 @@ from .hci import ( HCI_Switch_Role_Command, HCI_Set_Connection_Encryption_Command, HCI_StatusError, + HCI_SynchronousDataPacket, HCI_User_Confirmation_Request_Negative_Reply_Command, HCI_User_Confirmation_Request_Reply_Command, HCI_User_Passkey_Request_Negative_Reply_Command, @@ -161,6 +170,7 @@ from .core import ( from .utils import ( AsyncRunner, CompositeEventEmitter, + EventWatcher, setup_event_forwarding, composite_listener, deprecated, @@ -592,6 +602,46 @@ class ConnectionParametersPreferences: ConnectionParametersPreferences.default = ConnectionParametersPreferences() +# ----------------------------------------------------------------------------- +@dataclass +class ScoLink(CompositeEventEmitter): + device: Device + acl_connection: Connection + handle: int + link_type: int + + def __post_init__(self): + super().__init__() + + async def disconnect( + self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR + ) -> None: + await self.device.disconnect(self, reason) + + +# ----------------------------------------------------------------------------- +@dataclass +class CisLink(CompositeEventEmitter): + class State(IntEnum): + PENDING = 0 + ESTABLISHED = 1 + + device: Device + acl_connection: Connection # Based ACL connection + handle: int # CIS handle assigned by Controller (in LE_Set_CIG_Parameters Complete or LE_CIS_Request events) + cis_id: int # CIS ID assigned by Central device + cig_id: int # CIG ID assigned by Central device + state: State = State.PENDING + + def __post_init__(self): + super().__init__() + + async def disconnect( + self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR + ) -> None: + await self.device.disconnect(self, reason) + + # ----------------------------------------------------------------------------- class Connection(CompositeEventEmitter): device: Device @@ -870,6 +920,7 @@ class DeviceConfiguration: self.keystore = None self.gatt_services: List[Dict[str, Any]] = [] self.address_resolution_offload = False + self.cis_enabled = False def load_from_dict(self, config: Dict[str, Any]) -> None: # Load simple properties @@ -905,6 +956,7 @@ class DeviceConfiguration: self.address_resolution_offload = config.get( 'address_resolution_offload', self.address_resolution_offload ) + self.cis_enabled = config.get('cis_enabled', self.cis_enabled) # Load or synthesize an IRK irk = config.get('irk') @@ -1012,6 +1064,9 @@ class Device(CompositeEventEmitter): advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] config: DeviceConfiguration extended_advertising_handles: Set[int] + sco_links: Dict[int, ScoLink] + cis_links: Dict[int, CisLink] + _pending_cis: Dict[int, Tuple[int, int]] @composite_listener class Listener: @@ -1104,6 +1159,9 @@ class Device(CompositeEventEmitter): self.disconnecting = False self.connections = {} # Connections, by connection handle self.pending_connections = {} # Connections, by BD address (BR/EDR only) + self.sco_links = {} # ScoLinks, by connection handle (BR/EDR only) + self.cis_links = {} # CisLinks, by connection handle (LE only) + self._pending_cis = {} # (CIS_ID, CIG_ID), by CIS_handle self.classic_enabled = False self.inquiry_response = None self.address_resolver = None @@ -1133,6 +1191,7 @@ class Device(CompositeEventEmitter): self.le_enabled = config.le_enabled self.classic_enabled = config.classic_enabled self.le_simultaneous_enabled = config.le_simultaneous_enabled + self.cis_enabled = config.cis_enabled self.classic_sc_enabled = config.classic_sc_enabled self.classic_ssp_enabled = config.classic_ssp_enabled self.classic_smp_enabled = config.classic_smp_enabled @@ -1443,6 +1502,16 @@ class Device(CompositeEventEmitter): ) # type: ignore[call-arg] ) + if self.cis_enabled: + await self.send_command( + HCI_LE_Set_Host_Feature_Command( # type: ignore[call-arg] + bit_number=( + HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE + ), + bit_value=1, + ) + ) + if self.classic_enabled: await self.send_command( HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg] @@ -2366,7 +2435,9 @@ class Device(CompositeEventEmitter): check_result=True, ) - async def disconnect(self, connection, reason): + async def disconnect( + self, connection: Union[Connection, ScoLink, CisLink], reason: int + ) -> None: # Create a future so that we can wait for the disconnection's result pending_disconnection = asyncio.get_running_loop().create_future() connection.on('disconnection', pending_disconnection.set_result) @@ -2374,7 +2445,7 @@ class Device(CompositeEventEmitter): # Request a disconnection result = await self.send_command( - HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) + HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) # type: ignore[call-arg] ) try: @@ -2837,6 +2908,154 @@ class Device(CompositeEventEmitter): self.remove_listener('remote_name', handler) self.remove_listener('remote_name_failure', failure_handler) + # [LE only] + @experimental('Only for testing.') + async def setup_cig( + self, + cig_id: int, + cis_id: List[int], + sdu_interval: Tuple[int, int], + framing: int, + max_sdu: Tuple[int, int], + retransmission_number: int, + max_transport_latency: Tuple[int, int], + ) -> List[int]: + """Sends HCI_LE_Set_CIG_Parameters_Command. + + Args: + cig_id: CIG_ID. + cis_id: CID ID list. + sdu_interval: SDU intervals of (Central->Peripheral, Peripheral->Cental). + framing: Un-framing(0) or Framing(1). + max_sdu: Max SDU counts of (Central->Peripheral, Peripheral->Cental). + retransmission_number: retransmission_number. + max_transport_latency: Max transport latencies of + (Central->Peripheral, Peripheral->Cental). + + Returns: + List of created CIS handles corresponding to the same order of [cid_id]. + """ + num_cis = len(cis_id) + + response = await self.send_command( + HCI_LE_Set_CIG_Parameters_Command( # type: ignore[call-arg] + cig_id=cig_id, + sdu_interval_c_to_p=sdu_interval[0], + sdu_interval_p_to_c=sdu_interval[1], + worst_case_sca=0x00, # 251-500 ppm + packing=0x00, # Sequential + framing=framing, + max_transport_latency_c_to_p=max_transport_latency[0], + max_transport_latency_p_to_c=max_transport_latency[1], + cis_id=cis_id, + max_sdu_c_to_p=[max_sdu[0]] * num_cis, + max_sdu_p_to_c=[max_sdu[1]] * num_cis, + phy_c_to_p=[HCI_LE_2M_PHY] * num_cis, + phy_p_to_c=[HCI_LE_2M_PHY] * num_cis, + rtn_c_to_p=[retransmission_number] * num_cis, + rtn_p_to_c=[retransmission_number] * num_cis, + ), + check_result=True, + ) + + # Ideally, we should manage CIG lifecycle, but they are not useful for Unicast + # Server, so here it only provides a basic functionality for testing. + cis_handles = response.return_parameters.connection_handle[:] + for id, cis_handle in zip(cis_id, cis_handles): + self._pending_cis[cis_handle] = (id, cig_id) + + return cis_handles + + # [LE only] + @experimental('Only for testing.') + async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink]: + for cis_handle, acl_handle in cis_acl_pairs: + acl_connection = self.lookup_connection(acl_handle) + assert acl_connection + cis_id, cig_id = self._pending_cis.pop(cis_handle) + self.cis_links[cis_handle] = CisLink( + device=self, + acl_connection=acl_connection, + handle=cis_handle, + cis_id=cis_id, + cig_id=cig_id, + ) + + result = await self.send_command( + HCI_LE_Create_CIS_Command( # type: ignore[call-arg] + cis_connection_handle=[p[0] for p in cis_acl_pairs], + acl_connection_handle=[p[1] for p in cis_acl_pairs], + ), + ) + if result.status != HCI_COMMAND_STATUS_PENDING: + logger.warning( + 'HCI_LE_Create_CIS_Command failed: ' + f'{HCI_Constant.error_name(result.status)}' + ) + raise HCI_StatusError(result) + + pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {} + for cis_handle, _ in cis_acl_pairs: + pending_cis_establishments[ + cis_handle + ] = asyncio.get_running_loop().create_future() + + with closing(EventWatcher()) as watcher: + + @watcher.on(self, 'cis_establishment') + def on_cis_establishment(cis_link: CisLink) -> None: + if pending_future := pending_cis_establishments.get( + cis_link.handle, None + ): + pending_future.set_result(cis_link) + + return await asyncio.gather(*pending_cis_establishments.values()) + + # [LE only] + @experimental('Only for testing.') + async def accept_cis_request(self, handle: int) -> CisLink: + result = await self.send_command( + HCI_LE_Accept_CIS_Request_Command( # type: ignore[call-arg] + connection_handle=handle + ), + ) + if result.status != HCI_COMMAND_STATUS_PENDING: + logger.warning( + 'HCI_LE_Accept_CIS_Request_Command failed: ' + f'{HCI_Constant.error_name(result.status)}' + ) + raise HCI_StatusError(result) + + pending_cis_establishment = asyncio.get_running_loop().create_future() + + with closing(EventWatcher()) as watcher: + + @watcher.on(self, 'cis_establishment') + def on_cis_establishment(cis_link: CisLink) -> None: + if cis_link.handle == handle: + pending_cis_establishment.set_result(cis_link) + + return await pending_cis_establishment + + # [LE only] + @experimental('Only for testing.') + async def reject_cis_request( + self, + handle: int, + reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, + ) -> None: + result = await self.send_command( + HCI_LE_Reject_CIS_Request_Command( # type: ignore[call-arg] + connection_handle=handle, reason=reason + ), + ) + if result.status != HCI_COMMAND_STATUS_PENDING: + logger.warning( + 'HCI_LE_Reject_CIS_Request_Command failed: ' + f'{HCI_Constant.error_name(result.status)}' + ) + raise HCI_StatusError(result) + @host_event_handler def on_flush(self): self.emit('flush') @@ -3041,30 +3260,35 @@ class Device(CompositeEventEmitter): ) @host_event_handler - @with_connection_from_handle - def on_disconnection(self, connection, reason): - logger.debug( - f'*** Disconnection: [0x{connection.handle:04X}] ' - f'{connection.peer_address} as {connection.role_name}, reason={reason}' - ) - connection.emit('disconnection', reason) + def on_disconnection(self, connection_handle: int, reason: int) -> None: + if connection := self.connections.pop(connection_handle, None): + logger.debug( + f'*** Disconnection: [0x{connection.handle:04X}] ' + f'{connection.peer_address} as {connection.role_name}, reason={reason}' + ) + connection.emit('disconnection', reason) - # Remove the connection from the map - del self.connections[connection.handle] + # Cleanup subsystems that maintain per-connection state + self.gatt_server.on_disconnection(connection) - # Cleanup subsystems that maintain per-connection state - self.gatt_server.on_disconnection(connection) - - # Restart advertising if auto-restart is enabled - if self.auto_restart_advertising: - logger.debug('restarting advertising') - self.abort_on( - 'flush', - self.start_advertising( - advertising_type=self.advertising_type, - own_address_type=self.advertising_own_address_type, - auto_restart=True, - ), + # Restart advertising if auto-restart is enabled + if self.auto_restart_advertising: + logger.debug('restarting advertising') + self.abort_on( + 'flush', + self.start_advertising( + advertising_type=self.advertising_type, # type: ignore[arg-type] + own_address_type=self.advertising_own_address_type, # type: ignore[arg-type] + auto_restart=True, + ), + ) + elif sco_link := self.sco_links.pop(connection_handle, None): + sco_link.emit('disconnection', reason) + elif cis_link := self.cis_links.pop(connection_handle, None): + cis_link.emit('disconnection', reason) + else: + logger.error( + f'*** Unknown disconnection handle=0x{connection_handle}, reason={reason} ***' ) @host_event_handler @@ -3343,6 +3567,98 @@ class Device(CompositeEventEmitter): connection.emit('remote_name_failure', error) self.emit('remote_name_failure', address, error) + # [Classic only] + @host_event_handler + @with_connection_from_address + def on_sco_connection( + self, acl_connection: Connection, sco_handle: int, link_type: int + ) -> None: + logger.debug( + f'*** SCO connected: {acl_connection.peer_address}, ' + f'sco_handle=[0x{sco_handle:04X}], ' + f'link_type=[0x{link_type:02X}] ***' + ) + self.sco_links[sco_handle] = ScoLink( + device=self, + acl_connection=acl_connection, + handle=sco_handle, + link_type=link_type, + ) + + # [Classic only] + @host_event_handler + @with_connection_from_address + def on_sco_connection_failure( + self, acl_connection: Connection, status: int + ) -> None: + logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***') + + # [Classic only] + @host_event_handler + def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: + if sco_link := self.sco_links.get(sco_handle, None): + sco_link.emit('pdu', packet) + + # [LE only] + @host_event_handler + @with_connection_from_handle + def on_cis_request( + self, + acl_connection: Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ) -> None: + logger.debug( + f'*** CIS Request ' + f'acl_handle=[0x{acl_connection.handle:04X}]{acl_connection.peer_address}, ' + f'cis_handle=[0x{cis_handle:04X}], ' + f'cig_id=[0x{cig_id:02X}], ' + f'cis_id=[0x{cis_id:02X}] ***' + ) + # LE_CIS_Established event doesn't provide info, so we must store them here. + self.cis_links[cis_handle] = CisLink( + device=self, + acl_connection=acl_connection, + handle=cis_handle, + cig_id=cig_id, + cis_id=cis_id, + ) + self.emit('cis_request', acl_connection, cis_handle, cig_id, cis_id) + + # [LE only] + @host_event_handler + def on_cis_establishment(self, cis_handle: int) -> None: + cis_link = self.cis_links[cis_handle] + cis_link.state = CisLink.State.ESTABLISHED + + assert cis_link.acl_connection + + logger.debug( + f'*** CIS Establishment ' + f'{cis_link.acl_connection.peer_address}, ' + f'cis_handle=[0x{cis_handle:04X}], ' + f'cig_id=[0x{cis_link.cig_id:02X}], ' + f'cis_id=[0x{cis_link.cis_id:02X}] ***' + ) + + cis_link.emit('establishment') + self.emit('cis_establishment', cis_link) + + # [LE only] + @host_event_handler + def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: + logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') + if cis_link := self.cis_links.pop(cis_handle, None): + cis_link.emit('establishment_failure') + self.emit('cis_establishment_failure', cis_handle, status) + + # [LE only] + @host_event_handler + def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: + if cis_link := self.cis_links.get(handle, None): + cis_link.emit('pdu', packet) + @host_event_handler @with_connection_from_handle def on_connection_encryption_change(self, connection, encryption): diff --git a/bumble/hci.py b/bumble/hci.py index 67fe457..f978644 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4451,7 +4451,10 @@ class HCI_LE_Accept_CIS_Request_Command(HCI_Command): # ----------------------------------------------------------------------------- @HCI_Command.command( - fields=[('connection_handle', 2)], + fields=[ + ('connection_handle', 2), + ('reason', {'size': 1, 'mapper': HCI_Constant.error_name}), + ], ) class HCI_LE_Reject_CIS_Request_Command(HCI_Command): ''' @@ -4459,6 +4462,7 @@ class HCI_LE_Reject_CIS_Request_Command(HCI_Command): ''' connection_handle: int + reason: int # ----------------------------------------------------------------------------- diff --git a/bumble/host.py b/bumble/host.py index a649eb6..b06ceba 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -32,8 +32,8 @@ from .hci import ( Address, HCI_ACL_DATA_PACKET, HCI_COMMAND_PACKET, - HCI_COMMAND_COMPLETE_EVENT, HCI_EVENT_PACKET, + HCI_ISO_DATA_PACKET, HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND, HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND, @@ -52,6 +52,7 @@ from .hci import ( HCI_Constant, HCI_Error, HCI_Event, + HCI_IsoDataPacket, HCI_LE_Long_Term_Key_Request_Negative_Reply_Command, HCI_LE_Long_Term_Key_Request_Reply_Command, HCI_LE_Read_Buffer_Size_Command, @@ -75,7 +76,6 @@ from .core import ( BT_LE_TRANSPORT, ConnectionPHY, ConnectionParameters, - InvalidStateError, ) from .utils import AbortableEventEmitter from .transport.common import TransportLostError @@ -243,7 +243,7 @@ class Host(AbortableEventEmitter): # understand le_event_mask = bytes.fromhex('1F00000000000000') else: - le_event_mask = bytes.fromhex('FFFFF00000000000') + le_event_mask = bytes.fromhex('FFFFFFFF00000000') await self.send_command( HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask) @@ -495,6 +495,8 @@ class Host(AbortableEventEmitter): self.on_hci_acl_data_packet(cast(HCI_AclDataPacket, packet)) elif packet.hci_packet_type == HCI_SYNCHRONOUS_DATA_PACKET: self.on_hci_sco_data_packet(cast(HCI_SynchronousDataPacket, packet)) + elif packet.hci_packet_type == HCI_ISO_DATA_PACKET: + self.on_hci_iso_data_packet(cast(HCI_IsoDataPacket, packet)) else: logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') @@ -515,6 +517,10 @@ class Host(AbortableEventEmitter): # Experimental self.emit('sco_packet', packet.connection_handle, packet) + def on_hci_iso_data_packet(self, packet: HCI_IsoDataPacket) -> None: + # Experimental + self.emit('iso_packet', packet.connection_handle, packet) + def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: self.emit('l2cap_pdu', connection.handle, cid, pdu) @@ -715,6 +721,24 @@ class Host(AbortableEventEmitter): def on_hci_le_extended_advertising_report_event(self, event): self.on_hci_le_advertising_report_event(event) + def on_hci_le_cis_request_event(self, event): + self.emit( + 'cis_request', + event.acl_connection_handle, + event.cis_connection_handle, + event.cig_id, + event.cis_id, + ) + + def on_hci_le_cis_established_event(self, event): + # The remaining parameters are unused for now. + if event.status == HCI_SUCCESS: + self.emit('cis_establishment', event.connection_handle) + else: + self.emit( + 'cis_establishment_failure', event.connection_handle, event.status + ) + def on_hci_le_remote_connection_parameter_request_event(self, event): if event.connection_handle not in self.connections: logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle') diff --git a/examples/leaudio.json b/examples/leaudio.json new file mode 100644 index 0000000..4b6edfc --- /dev/null +++ b/examples/leaudio.json @@ -0,0 +1,5 @@ +{ + "name": "Bumble-LEA", + "keystore": "JsonKeyStore", + "advertising_interval": 100 +} diff --git a/examples/run_cig_setup.py b/examples/run_cig_setup.py new file mode 100644 index 0000000..a7f7260 --- /dev/null +++ b/examples/run_cig_setup.py @@ -0,0 +1,105 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import sys +import os +from bumble.device import ( + HCI_LE_Set_Extended_Advertising_Parameters_Command, + Device, + Connection, +) +from bumble.hci import OwnAddressType + +from bumble.transport import open_transport_or_link + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 3: + print( + 'Usage: run_cig_setup.py ' + ' ' + ) + print( + 'example: run_cig_setup.py device1.json' + 'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402' + ) + return + + print('<<< connecting to HCI...') + hci_transports = await asyncio.gather( + open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3]) + ) + print('<<< connected') + + devices = [ + Device.from_config_file_with_hci( + sys.argv[1], hci_transport.source, hci_transport.sink + ) + for hci_transport in hci_transports + ] + + devices[0].cis_enabled = True + devices[1].cis_enabled = True + + await asyncio.gather(*[device.power_on() for device in devices]) + await devices[0].start_extended_advertising( + advertising_properties=( + HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING + ), + own_address_type=OwnAddressType.PUBLIC, + ) + + connection = await devices[1].connect( + devices[0].public_address, own_address_type=OwnAddressType.PUBLIC + ) + + cid_ids = [2, 3] + cis_handles = await devices[1].setup_cig( + cig_id=1, + cis_id=cid_ids, + sdu_interval=(10000, 0), + framing=0, + max_sdu=(120, 0), + retransmission_number=13, + max_transport_latency=(100, 0), + ) + + def on_cis_request( + connection: Connection, cis_handle: int, _cig_id: int, _cis_id: int + ): + connection.abort_on('disconnection', devices[0].accept_cis_request(cis_handle)) + + devices[0].on('cis_request', on_cis_request) + + cis_links = await devices[1].create_cis( + [(cis, connection.handle) for cis in cis_handles] + ) + + for cis_link in cis_links: + await cis_link.disconnect() + + await asyncio.gather( + *[hci_transport.source.terminated for hci_transport in hci_transports] + ) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) From 40ae661ee569e7126eef3a7477935640f820fa12 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 30 Nov 2023 12:49:46 +0800 Subject: [PATCH 3/8] More SCO support and warnings and typo fix --- bumble/device.py | 11 ++++- bumble/hfp.py | 10 ++-- examples/run_cig_setup.py | 8 +-- examples/run_esco_connection.py | 87 +++++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 examples/run_esco_connection.py diff --git a/bumble/device.py b/bumble/device.py index 75caece..bda761b 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -3570,6 +3570,7 @@ class Device(CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address + @experimental('Only for testing.') def on_sco_connection( self, acl_connection: Connection, sco_handle: int, link_type: int ) -> None: @@ -3578,23 +3579,27 @@ class Device(CompositeEventEmitter): f'sco_handle=[0x{sco_handle:04X}], ' f'link_type=[0x{link_type:02X}] ***' ) - self.sco_links[sco_handle] = ScoLink( + sco_link = self.sco_links[sco_handle] = ScoLink( device=self, acl_connection=acl_connection, handle=sco_handle, link_type=link_type, ) + self.emit('sco_connection', sco_link) # [Classic only] @host_event_handler @with_connection_from_address + @experimental('Only for testing.') def on_sco_connection_failure( self, acl_connection: Connection, status: int ) -> None: logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***') + self.emit('sco_connection_failure') # [Classic only] @host_event_handler + @experimental('Only for testing') def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: if sco_link := self.sco_links.get(sco_handle, None): sco_link.emit('pdu', packet) @@ -3602,6 +3607,7 @@ class Device(CompositeEventEmitter): # [LE only] @host_event_handler @with_connection_from_handle + @experimental('Only for testing') def on_cis_request( self, acl_connection: Connection, @@ -3628,6 +3634,7 @@ class Device(CompositeEventEmitter): # [LE only] @host_event_handler + @experimental('Only for testing') def on_cis_establishment(self, cis_handle: int) -> None: cis_link = self.cis_links[cis_handle] cis_link.state = CisLink.State.ESTABLISHED @@ -3647,6 +3654,7 @@ class Device(CompositeEventEmitter): # [LE only] @host_event_handler + @experimental('Only for testing') def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') if cis_link := self.cis_links.pop(cis_handle, None): @@ -3655,6 +3663,7 @@ class Device(CompositeEventEmitter): # [LE only] @host_event_handler + @experimental('Only for testing') def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: if cis_link := self.cis_links.get(handle, None): cis_link.emit('pdu', packet) diff --git a/bumble/hfp.py b/bumble/hfp.py index 42683d5..a655b8f 100644 --- a/bumble/hfp.py +++ b/bumble/hfp.py @@ -850,10 +850,10 @@ class EscoParameters: # Common input_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = ( - HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT + HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.PCM ) output_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = ( - HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT + HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.PCM ) input_coded_data_size: int = 16 output_coded_data_size: int = 16 @@ -960,6 +960,8 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters( | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 ), + input_bandwidth=32000, + output_bandwidth=32000, retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY, ) @@ -974,10 +976,12 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters( | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 ), + input_bandwidth=32000, + output_bandwidth=32000, retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY, ) -ESCO_PERAMETERS = { +ESCO_PARAMETERS = { DefaultCodecParameters.SCO_CVSD_D0: _ESCO_PARAMETERS_CVSD_D0, DefaultCodecParameters.SCO_CVSD_D1: _ESCO_PARAMETERS_CVSD_D1, DefaultCodecParameters.ESCO_CVSD_S1: _ESCO_PARAMETERS_CVSD_S1, diff --git a/examples/run_cig_setup.py b/examples/run_cig_setup.py index a7f7260..ff12bfc 100644 --- a/examples/run_cig_setup.py +++ b/examples/run_cig_setup.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 Google LLC +# Copyright 2021-2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,13 @@ import logging import sys import os from bumble.device import ( - HCI_LE_Set_Extended_Advertising_Parameters_Command, Device, Connection, ) -from bumble.hci import OwnAddressType +from bumble.hci import ( + OwnAddressType, + HCI_LE_Set_Extended_Advertising_Parameters_Command, +) from bumble.transport import open_transport_or_link diff --git a/examples/run_esco_connection.py b/examples/run_esco_connection.py new file mode 100644 index 0000000..a136360 --- /dev/null +++ b/examples/run_esco_connection.py @@ -0,0 +1,87 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import dataclasses +import logging +import sys +import os +from bumble.core import BT_BR_EDR_TRANSPORT +from bumble.device import Device, ScoLink +from bumble.hci import HCI_Enhanced_Setup_Synchronous_Connection_Command +from bumble.hfp import DefaultCodecParameters, ESCO_PARAMETERS + +from bumble.transport import open_transport_or_link + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 3: + print( + 'Usage: run_esco_connection.py ' + ' ' + ) + print( + 'example: run_esco_connection.py classic1.json' + 'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402' + ) + return + + print('<<< connecting to HCI...') + hci_transports = await asyncio.gather( + open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3]) + ) + print('<<< connected') + + devices = [ + Device.from_config_file_with_hci( + sys.argv[1], hci_transport.source, hci_transport.sink + ) + for hci_transport in hci_transports + ] + + devices[0].classic_enabled = True + devices[1].classic_enabled = True + + await asyncio.gather(*[device.power_on() for device in devices]) + + connections = await asyncio.gather( + devices[0].accept(devices[1].public_address), + devices[1].connect(devices[0].public_address, transport=BT_BR_EDR_TRANSPORT), + ) + + def on_sco(sco_link: ScoLink): + connections[0].abort_on('disconnection', sco_link.disconnect()) + + devices[0].once('sco_connection', on_sco) + + await devices[0].send_command( + HCI_Enhanced_Setup_Synchronous_Connection_Command( + connection_handle=connections[0].handle, + **dataclasses.asdict(ESCO_PARAMETERS[DefaultCodecParameters.ESCO_CVSD_S3]) + # type: ignore[call-args] + ) + ) + + await asyncio.gather( + *[hci_transport.source.terminated for hci_transport in hci_transports] + ) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) From f27015d1b702f5fe6ac2c9ee58744554b6501c8f Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 30 Nov 2023 23:47:19 +0800 Subject: [PATCH 4/8] Refactor CodingFormat As CodingFormat is now used by HFP and LEA, and vendor specific codecs are introduced, this object needs to provide more information. --- .vscode/settings.json | 2 + bumble/hci.py | 77 ++++++++++++++++++++++----------- bumble/hfp.py | 58 ++++++++++++++----------- examples/run_esco_connection.py | 2 +- tests/hci_test.py | 15 +++++++ 5 files changed, 101 insertions(+), 53 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 4011e64..b564a38 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,6 +23,7 @@ "CONNECTIONLESS", "csip", "csrcs", + "CVSD", "datagram", "DATALINK", "delayreport", @@ -40,6 +41,7 @@ "libc", "libusb", "MITM", + "MSBC", "NDIS", "netsim", "NONBLOCK", diff --git a/bumble/hci.py b/bumble/hci.py index f978644..376a940 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import collections +import dataclasses import enum import functools import logging @@ -1382,6 +1383,45 @@ HCI_LE_SUPPORTED_FEATURES_NAMES = { STATUS_SPEC = {'size': 1, 'mapper': lambda x: HCI_Constant.status_name(x)} +class CodecID(enum.IntEnum): + # fmt: off + U_LOG = 0x00 + A_LOG = 0x01 + CVSD = 0x02 + TRANSPARENT = 0x03 + LINEAR_PCM = 0x04 + MSBC = 0x05 + LC3 = 0x06 + G729A = 0x07 + VENDOR_SPECIFIC = 0xFF + + +@dataclasses.dataclass(frozen=True) +class CodingFormat: + codec_id: CodecID + company_id: int = 0 + vendor_specific_codec_id: int = 0 + + @classmethod + def parse_from_bytes(cls, data: bytes, offset: int): + (codec_id, company_id, vendor_specific_codec_id) = struct.unpack_from( + ' bytes: + return struct.pack( + ' bytes: + return self.to_bytes() + + # ----------------------------------------------------------------------------- class HCI_Constant: @staticmethod @@ -1888,6 +1928,7 @@ Address.NIL = Address(b"\xff\xff\xff\xff\xff\xff", Address.PUBLIC_DEVICE_ADDRESS Address.ANY = Address(b"\x00\x00\x00\x00\x00\x00", Address.PUBLIC_DEVICE_ADDRESS) Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_ADDRESS) + # ----------------------------------------------------------------------------- class OwnAddressType: PUBLIC = 0 @@ -2445,14 +2486,14 @@ class HCI_IO_Capability_Request_Negative_Reply_Command(HCI_Command): ('connection_handle', 2), ('transmit_bandwidth', 4), ('receive_bandwidth', 4), - ('transmit_coding_format', 5), - ('receive_coding_format', 5), + ('transmit_coding_format', CodingFormat.parse_from_bytes), + ('receive_coding_format', CodingFormat.parse_from_bytes), ('transmit_codec_frame_size', 2), ('receive_codec_frame_size', 2), ('input_bandwidth', 4), ('output_bandwidth', 4), - ('input_coding_format', 5), - ('output_coding_format', 5), + ('input_coding_format', CodingFormat.parse_from_bytes), + ('output_coding_format', CodingFormat.parse_from_bytes), ('input_coded_data_size', 2), ('output_coded_data_size', 2), ('input_pcm_data_format', 1), @@ -2473,22 +2514,6 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_Command): See Bluetooth spec @ 7.1.45 Enhanced Setup Synchronous Connection Command ''' - class CodingFormat(enum.IntEnum): - U_LOG = 0x00 - A_LOG = 0x01 - CVSD = 0x02 - TRANSPARENT = 0x03 - PCM = 0x04 - MSBC = 0x05 - LC3 = 0x06 - G729A = 0x07 - - def to_bytes(self): - return self.value.to_bytes(5, 'little') - - def __bytes__(self): - return self.to_bytes() - class PcmDataFormat(enum.IntEnum): NA = 0x00 ONES_COMPLEMENT = 0x01 @@ -2525,14 +2550,14 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_Command): ('bd_addr', Address.parse_address), ('transmit_bandwidth', 4), ('receive_bandwidth', 4), - ('transmit_coding_format', 5), - ('receive_coding_format', 5), + ('transmit_coding_format', CodingFormat.parse_from_bytes), + ('receive_coding_format', CodingFormat.parse_from_bytes), ('transmit_codec_frame_size', 2), ('receive_codec_frame_size', 2), ('input_bandwidth', 4), ('output_bandwidth', 4), - ('input_coding_format', 5), - ('output_coding_format', 5), + ('input_coding_format', CodingFormat.parse_from_bytes), + ('output_coding_format', CodingFormat.parse_from_bytes), ('input_coded_data_size', 2), ('output_coded_data_size', 2), ('input_pcm_data_format', 1), @@ -4471,7 +4496,7 @@ class HCI_LE_Reject_CIS_Request_Command(HCI_Command): ('connection_handle', 2), ('data_path_direction', 1), ('data_path_id', 1), - ('codec_id', 5), + ('codec_id', CodingFormat.parse_from_bytes), ('controller_delay', 3), ('codec_configuration', '*'), ], @@ -4488,7 +4513,7 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command): connection_handle: int data_path_direction: int data_path_id: int - codec_id: int + codec_id: CodingFormat controller_delay: int codec_configuration: int diff --git a/bumble/hfp.py b/bumble/hfp.py index a655b8f..2079e32 100644 --- a/bumble/hfp.py +++ b/bumble/hfp.py @@ -22,7 +22,7 @@ import dataclasses import enum import traceback import warnings -from typing import Dict, List, Union, Set, TYPE_CHECKING +from typing import Dict, List, Union, Set, Any, TYPE_CHECKING from . import at from . import rfcomm @@ -35,7 +35,11 @@ from bumble.core import ( BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID, ) -from bumble.hci import HCI_Enhanced_Setup_Synchronous_Connection_Command +from bumble.hci import ( + HCI_Enhanced_Setup_Synchronous_Connection_Command, + CodingFormat, + CodecID, +) from bumble.sdp import ( DataElement, ServiceAttribute, @@ -66,6 +70,7 @@ class HfpProtocolError(ProtocolError): # Protocol Support # ----------------------------------------------------------------------------- + # ----------------------------------------------------------------------------- class HfpProtocol: dlc: rfcomm.DLC @@ -842,19 +847,15 @@ class DefaultCodecParameters(enum.IntEnum): @dataclasses.dataclass class EscoParameters: # Codec specific - transmit_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat - receive_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat + transmit_coding_format: CodingFormat + receive_coding_format: CodingFormat packet_type: HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType retransmission_effort: HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort max_latency: int # Common - input_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = ( - HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.PCM - ) - output_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = ( - HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.PCM - ) + input_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM) + output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM) input_coded_data_size: int = 16 output_coded_data_size: int = 16 input_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = ( @@ -880,26 +881,31 @@ class EscoParameters: transmit_codec_frame_size: int = 60 receive_codec_frame_size: int = 60 + def asdict(self) -> Dict[str, Any]: + # dataclasses.asdict() will recursively deep-copy the entire object, + # which is expensive and breaks CodingFormat object, so let it simply copy here. + return self.__dict__ + _ESCO_PARAMETERS_CVSD_D0 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0xFFFF, packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV1, retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION, ) _ESCO_PARAMETERS_CVSD_D1 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0xFFFF, packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV3, retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION, ) _ESCO_PARAMETERS_CVSD_S1 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0x0007, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 @@ -912,8 +918,8 @@ _ESCO_PARAMETERS_CVSD_S1 = EscoParameters( ) _ESCO_PARAMETERS_CVSD_S2 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0x0007, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 @@ -925,8 +931,8 @@ _ESCO_PARAMETERS_CVSD_S2 = EscoParameters( ) _ESCO_PARAMETERS_CVSD_S3 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0x000A, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 @@ -938,8 +944,8 @@ _ESCO_PARAMETERS_CVSD_S3 = EscoParameters( ) _ESCO_PARAMETERS_CVSD_S4 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD, + transmit_coding_format=CodingFormat(CodecID.CVSD), + receive_coding_format=CodingFormat(CodecID.CVSD), max_latency=0x000C, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 @@ -951,8 +957,8 @@ _ESCO_PARAMETERS_CVSD_S4 = EscoParameters( ) _ESCO_PARAMETERS_MSBC_T1 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC, + transmit_coding_format=CodingFormat(CodecID.MSBC), + receive_coding_format=CodingFormat(CodecID.MSBC), max_latency=0x0008, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 @@ -966,8 +972,8 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters( ) _ESCO_PARAMETERS_MSBC_T2 = EscoParameters( - transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC, - receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC, + transmit_coding_format=CodingFormat(CodecID.MSBC), + receive_coding_format=CodingFormat(CodecID.MSBC), max_latency=0x000D, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 diff --git a/examples/run_esco_connection.py b/examples/run_esco_connection.py index a136360..0ad34c4 100644 --- a/examples/run_esco_connection.py +++ b/examples/run_esco_connection.py @@ -72,7 +72,7 @@ async def main() -> None: await devices[0].send_command( HCI_Enhanced_Setup_Synchronous_Connection_Command( connection_handle=connections[0].handle, - **dataclasses.asdict(ESCO_PARAMETERS[DefaultCodecParameters.ESCO_CVSD_S3]) + **ESCO_PARAMETERS[DefaultCodecParameters.ESCO_CVSD_S3].asdict(), # type: ignore[call-args] ) ) diff --git a/tests/hci_test.py b/tests/hci_test.py index c648592..12f611f 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -24,6 +24,8 @@ from bumble.hci import ( HCI_RESET_COMMAND, HCI_SUCCESS, Address, + CodingFormat, + CodecID, HCI_Command, HCI_Command_Complete_Event, HCI_Command_Status_Event, @@ -442,6 +444,19 @@ def test_HCI_LE_Set_Extended_Advertising_Enable_Command(): basic_check(command) +# ----------------------------------------------------------------------------- +def test_HCI_LE_Setup_ISO_Data_Path_Command(): + command = HCI_Packet.from_bytes(bytes.fromhex('016e200d60000001030000000000000000')) + + assert command.connection_handle == 0x0060 + assert command.data_path_direction == 0x00 + assert command.data_path_id == 0x01 + assert command.codec_id == CodingFormat(CodecID.TRANSPARENT) + assert command.controller_delay == 0 + + basic_check(command) + + # ----------------------------------------------------------------------------- def test_address(): a = Address('C4:F2:17:1A:1D:BB') From 3fc71a0266b2bff72bea0f6915734b70d3eebb58 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 30 Nov 2023 23:58:30 +0800 Subject: [PATCH 5/8] Add variable-length bytes field --- bumble/hci.py | 15 +++++++++++++-- tests/hci_test.py | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/bumble/hci.py b/bumble/hci.py index 376a940..f652d23 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1517,6 +1517,12 @@ class HCI_Object: # The rest of the bytes field_value = data[offset:] return (field_value, len(field_value)) + if field_type == 'v': + # Variable-length bytes field, with 1-byte length at the beginning + field_length = data[offset] + offset += 1 + field_value = data[offset : offset + field_length] + return (field_value, field_length + 1) if field_type == 1: # 8-bit unsigned return (data[offset], 1) @@ -1621,6 +1627,11 @@ class HCI_Object: raise ValueError('value too large for *-typed field') else: field_bytes = bytes(field_value) + elif field_type == 'v': + # Variable-length bytes field, with 1-byte length at the beginning + field_bytes = bytes(field_bytes) + field_length = len(field_bytes) + field_bytes = bytes([field_length]) + field_bytes elif isinstance(field_value, (bytes, bytearray)) or hasattr( field_value, 'to_bytes' ): @@ -4498,7 +4509,7 @@ class HCI_LE_Reject_CIS_Request_Command(HCI_Command): ('data_path_id', 1), ('codec_id', CodingFormat.parse_from_bytes), ('controller_delay', 3), - ('codec_configuration', '*'), + ('codec_configuration', 'v'), ], return_parameters_fields=[ ('status', STATUS_SPEC), @@ -4515,7 +4526,7 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command): data_path_id: int codec_id: CodingFormat controller_delay: int - codec_configuration: int + codec_configuration: bytes # ----------------------------------------------------------------------------- diff --git a/tests/hci_test.py b/tests/hci_test.py index 12f611f..b6024ff 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -453,6 +453,7 @@ def test_HCI_LE_Setup_ISO_Data_Path_Command(): assert command.data_path_id == 0x01 assert command.codec_id == CodingFormat(CodecID.TRANSPARENT) assert command.controller_delay == 0 + assert command.codec_configuration == b'' basic_check(command) From 10a3833893f49bd7e650514b919d73a1d503b621 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Sat, 2 Dec 2023 19:10:05 +0800 Subject: [PATCH 6/8] Remove # type: ignore[call-arg] in HCI_Command builders --- bumble/device.py | 112 +++++++++++++++----------------- bumble/hci.py | 7 +- bumble/l2cap.py | 2 +- bumble/smp.py | 2 +- examples/run_esco_connection.py | 1 - 5 files changed, 62 insertions(+), 62 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index bda761b..9ae502e 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1432,7 +1432,7 @@ class Device(CompositeEventEmitter): await self.host.reset() # Try to get the public address from the controller - response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg] + response = await self.send_command(HCI_Read_BD_ADDR_Command()) if response.return_parameters.status == HCI_SUCCESS: logger.debug( color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow') @@ -1455,7 +1455,7 @@ class Device(CompositeEventEmitter): HCI_Write_LE_Host_Support_Command( le_supported_host=int(self.le_enabled), simultaneous_le_host=int(self.le_simultaneous_enabled), - ) # type: ignore[call-arg] + ) ) if self.le_enabled: @@ -1465,7 +1465,7 @@ class Device(CompositeEventEmitter): if self.host.supports_command(HCI_LE_RAND_COMMAND): # Get 8 random bytes response = await self.send_command( - HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg] + HCI_LE_Rand_Command(), check_result=True ) # Ensure the address bytes can be a static random address @@ -1486,7 +1486,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Random_Address_Command( random_address=self.random_address - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1499,12 +1499,12 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Address_Resolution_Enable_Command( address_resolution_enable=1 - ) # type: ignore[call-arg] + ) ) if self.cis_enabled: await self.send_command( - HCI_LE_Set_Host_Feature_Command( # type: ignore[call-arg] + HCI_LE_Set_Host_Feature_Command( bit_number=( HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE ), @@ -1514,20 +1514,20 @@ class Device(CompositeEventEmitter): if self.classic_enabled: await self.send_command( - HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg] + HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) ) await self.send_command( - HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg] + HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) ) await self.send_command( HCI_Write_Simple_Pairing_Mode_Command( simple_pairing_mode=int(self.classic_ssp_enabled) - ) # type: ignore[call-arg] + ) ) await self.send_command( HCI_Write_Secure_Connections_Host_Support_Command( secure_connections_host_support=int(self.classic_sc_enabled) - ) # type: ignore[call-arg] + ) ) await self.set_connectable(self.connectable) await self.set_discoverable(self.discoverable) @@ -1551,7 +1551,7 @@ class Device(CompositeEventEmitter): self.address_resolver = smp.AddressResolver(resolving_keys) if self.address_resolution_offload: - await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg] + await self.send_command(HCI_LE_Clear_Resolving_List_Command()) for irk, address in resolving_keys: await self.send_command( @@ -1560,7 +1560,7 @@ class Device(CompositeEventEmitter): peer_identity_address=address, peer_irk=irk, local_irk=self.irk, - ) # type: ignore[call-arg] + ) ) def supports_le_feature(self, feature): @@ -1595,7 +1595,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Advertising_Data_Command( advertising_data=self.advertising_data - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1604,7 +1604,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Scan_Response_Data_Command( scan_response_data=self.scan_response_data - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1630,13 +1630,13 @@ class Device(CompositeEventEmitter): peer_address=peer_address, advertising_channel_map=7, advertising_filter_policy=0, - ), # type: ignore[call-arg] + ), check_result=True, ) # Enable advertising await self.send_command( - HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg] + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), check_result=True, ) @@ -1649,7 +1649,7 @@ class Device(CompositeEventEmitter): # Disable advertising if self.advertising: await self.send_command( - HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg] + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), check_result=True, ) @@ -1716,7 +1716,7 @@ class Device(CompositeEventEmitter): secondary_advertising_phy=1, # LE 1M advertising_sid=0, scan_request_notification_enable=0, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1728,7 +1728,7 @@ class Device(CompositeEventEmitter): operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, fragment_preference=0x01, # Should not fragment advertising_data=advertising_data, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1740,7 +1740,7 @@ class Device(CompositeEventEmitter): operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, fragment_preference=0x01, # Should not fragment scan_response_data=scan_response, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1752,7 +1752,7 @@ class Device(CompositeEventEmitter): HCI_LE_Set_Advertising_Set_Random_Address_Command( advertising_handle=adv_handle, random_address=self.random_address, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1763,13 +1763,13 @@ class Device(CompositeEventEmitter): advertising_handles=[adv_handle], durations=[0], # Forever max_extended_advertising_events=[0], # Infinite - ), # type: ignore[call-arg] + ), check_result=True, ) except HCI_Error as error: # When any step fails, cleanup the advertising handle. await self.send_command( - HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), # type: ignore[call-arg] + HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), check_result=False, ) raise error @@ -1791,12 +1791,12 @@ class Device(CompositeEventEmitter): advertising_handles=[adv_handle], durations=[0], max_extended_advertising_events=[0], - ), # type: ignore[call-arg] + ), check_result=True, ) # Remove advertising set await self.send_command( - HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), # type: ignore[call-arg] + HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), check_result=True, ) self.extended_advertising_handles.remove(adv_handle) @@ -1864,7 +1864,7 @@ class Device(CompositeEventEmitter): scan_types=[scan_type] * scanning_phy_count, scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count, scan_windows=[int(scan_window / 0.625)] * scanning_phy_count, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1875,7 +1875,7 @@ class Device(CompositeEventEmitter): filter_duplicates=1 if filter_duplicates else 0, duration=0, # TODO allow other values period=0, # TODO allow other values - ), # type: ignore[call-arg] + ), check_result=True, ) else: @@ -1893,7 +1893,7 @@ class Device(CompositeEventEmitter): le_scan_window=int(scan_window / 0.625), own_address_type=own_address_type, scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1901,7 +1901,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Scan_Enable_Command( le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0 - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -1914,12 +1914,12 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Extended_Scan_Enable_Command( enable=0, filter_duplicates=0, duration=0, period=0 - ), # type: ignore[call-arg] + ), check_result=True, ) else: await self.send_command( - HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg] + HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), check_result=True, ) @@ -1939,7 +1939,7 @@ class Device(CompositeEventEmitter): async def start_discovery(self, auto_restart: bool = True) -> None: await self.send_command( - HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg] + HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), check_result=True, ) @@ -1948,7 +1948,7 @@ class Device(CompositeEventEmitter): lap=HCI_GENERAL_INQUIRY_LAP, inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH, num_responses=0, # Unlimited number of responses. - ) # type: ignore[call-arg] + ) ) if response.status != HCI_Command_Status_Event.PENDING: self.discovering = False @@ -1959,7 +1959,7 @@ class Device(CompositeEventEmitter): async def stop_discovery(self) -> None: if self.discovering: - await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg] + await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) self.auto_restart_inquiry = True self.discovering = False @@ -2007,7 +2007,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_Write_Extended_Inquiry_Response_Command( fec_required=0, extended_inquiry_response=self.inquiry_response - ), # type: ignore[call-arg] + ), check_result=True, ) await self.set_scan_enable( @@ -2196,7 +2196,7 @@ class Device(CompositeEventEmitter): supervision_timeouts=supervision_timeouts, min_ce_lengths=min_ce_lengths, max_ce_lengths=max_ce_lengths, - ) # type: ignore[call-arg] + ) ) else: if HCI_LE_1M_PHY not in connection_parameters_preferences: @@ -2225,7 +2225,7 @@ class Device(CompositeEventEmitter): supervision_timeout=int(prefs.supervision_timeout / 10), min_ce_length=int(prefs.min_ce_length / 0.625), max_ce_length=int(prefs.max_ce_length / 0.625), - ) # type: ignore[call-arg] + ) ) else: # Save pending connection @@ -2242,7 +2242,7 @@ class Device(CompositeEventEmitter): clock_offset=0x0000, allow_role_switch=0x01, reserved=0, - ) # type: ignore[call-arg] + ) ) if result.status != HCI_Command_Status_Event.PENDING: @@ -2261,10 +2261,10 @@ class Device(CompositeEventEmitter): ) except asyncio.TimeoutError: if transport == BT_LE_TRANSPORT: - await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg] + await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) else: await self.send_command( - HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg] + HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) ) try: @@ -2378,7 +2378,7 @@ class Device(CompositeEventEmitter): try: # Accept connection request await self.send_command( - HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg] + HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) ) # Wait for connection complete @@ -2445,7 +2445,7 @@ class Device(CompositeEventEmitter): # Request a disconnection result = await self.send_command( - HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) # type: ignore[call-arg] + HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) ) try: @@ -2476,7 +2476,7 @@ class Device(CompositeEventEmitter): connection_handle=connection.handle, tx_octets=tx_octets, tx_time=tx_time, - ), # type: ignore[call-arg] + ), check_result=True, ) @@ -2522,7 +2522,7 @@ class Device(CompositeEventEmitter): supervision_timeout=supervision_timeout, min_ce_length=min_ce_length, max_ce_length=max_ce_length, - ) # type: ignore[call-arg] + ) ) if result.status != HCI_Command_Status_Event.PENDING: raise HCI_StatusError(result) @@ -2850,7 +2850,7 @@ class Device(CompositeEventEmitter): try: result = await self.send_command( - HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role) # type: ignore[call-arg] + HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role) ) if result.status != HCI_COMMAND_STATUS_PENDING: logger.warning( @@ -2892,7 +2892,7 @@ class Device(CompositeEventEmitter): page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2, reserved=0, clock_offset=0, # TODO investigate non-0 values - ) # type: ignore[call-arg] + ) ) if result.status != HCI_COMMAND_STATUS_PENDING: @@ -2938,7 +2938,7 @@ class Device(CompositeEventEmitter): num_cis = len(cis_id) response = await self.send_command( - HCI_LE_Set_CIG_Parameters_Command( # type: ignore[call-arg] + HCI_LE_Set_CIG_Parameters_Command( cig_id=cig_id, sdu_interval_c_to_p=sdu_interval[0], sdu_interval_p_to_c=sdu_interval[1], @@ -2982,7 +2982,7 @@ class Device(CompositeEventEmitter): ) result = await self.send_command( - HCI_LE_Create_CIS_Command( # type: ignore[call-arg] + HCI_LE_Create_CIS_Command( cis_connection_handle=[p[0] for p in cis_acl_pairs], acl_connection_handle=[p[1] for p in cis_acl_pairs], ), @@ -3015,9 +3015,7 @@ class Device(CompositeEventEmitter): @experimental('Only for testing.') async def accept_cis_request(self, handle: int) -> CisLink: result = await self.send_command( - HCI_LE_Accept_CIS_Request_Command( # type: ignore[call-arg] - connection_handle=handle - ), + HCI_LE_Accept_CIS_Request_Command(connection_handle=handle), ) if result.status != HCI_COMMAND_STATUS_PENDING: logger.warning( @@ -3045,9 +3043,7 @@ class Device(CompositeEventEmitter): reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, ) -> None: result = await self.send_command( - HCI_LE_Reject_CIS_Request_Command( # type: ignore[call-arg] - connection_handle=handle, reason=reason - ), + HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason), ) if result.status != HCI_COMMAND_STATUS_PENDING: logger.warning( @@ -3439,7 +3435,7 @@ class Device(CompositeEventEmitter): try: if await connection.abort_on('disconnection', method()): await self.host.send_command( - HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg] + HCI_User_Confirmation_Request_Reply_Command( bd_addr=connection.peer_address ) ) @@ -3448,7 +3444,7 @@ class Device(CompositeEventEmitter): logger.warning(f'exception while confirming: {error}') await self.host.send_command( - HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg] + HCI_User_Confirmation_Request_Negative_Reply_Command( bd_addr=connection.peer_address ) ) @@ -3469,7 +3465,7 @@ class Device(CompositeEventEmitter): ) if number is not None: await self.host.send_command( - HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg] + HCI_User_Passkey_Request_Reply_Command( bd_addr=connection.peer_address, numeric_value=number ) ) @@ -3478,7 +3474,7 @@ class Device(CompositeEventEmitter): logger.warning(f'exception while asking for pass-key: {error}') await self.host.send_command( - HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg] + HCI_User_Passkey_Request_Negative_Reply_Command( bd_addr=connection.peer_address ) ) diff --git a/bumble/hci.py b/bumble/hci.py index f652d23..829933f 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -2018,6 +2018,7 @@ class HCI_Command(HCI_Packet): hci_packet_type = HCI_COMMAND_PACKET command_names: Dict[int, str] = {} command_classes: Dict[int, Type[HCI_Command]] = {} + op_code: int @staticmethod def command(fields=(), return_parameters_fields=()): @@ -2103,7 +2104,11 @@ class HCI_Command(HCI_Packet): return_parameters.fields = cls.return_parameters_fields return return_parameters - def __init__(self, op_code, parameters=None, **kwargs): + def __init__(self, op_code=-1, parameters=None, **kwargs): + # Since the legacy implementation relies on an __init__ injector, typing always + # complains that positional argument op_code is not passed, so here sets a + # default value to allow building derived HCI_Command without op_code. + assert op_code != -1 super().__init__(HCI_Command.command_name(op_code)) if (fields := getattr(self, 'fields', None)) and kwargs: HCI_Object.init_from_fields(self, fields, kwargs) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 4ccdeab..ce3385d 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -1926,7 +1926,7 @@ class ChannelManager: supervision_timeout=request.timeout, min_ce_length=0, max_ce_length=0, - ) # type: ignore[call-arg] + ) ) else: self.send_control_frame( diff --git a/bumble/smp.py b/bumble/smp.py index 25dd46b..f8879c6 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1090,7 +1090,7 @@ class Session: # We can now encrypt the connection with the short term key, so that we can # distribute the long term and/or other keys over an encrypted connection self.manager.device.host.send_command_sync( - HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg] + HCI_LE_Enable_Encryption_Command( connection_handle=self.connection.handle, random_number=bytes(8), encrypted_diversifier=0, diff --git a/examples/run_esco_connection.py b/examples/run_esco_connection.py index 0ad34c4..6f3e800 100644 --- a/examples/run_esco_connection.py +++ b/examples/run_esco_connection.py @@ -73,7 +73,6 @@ async def main() -> None: HCI_Enhanced_Setup_Synchronous_Connection_Command( connection_handle=connections[0].handle, **ESCO_PARAMETERS[DefaultCodecParameters.ESCO_CVSD_S3].asdict(), - # type: ignore[call-args] ) ) From dff14e1258043fb71c42fadb990301bc251e1b7e Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 29 Nov 2023 21:41:14 +0800 Subject: [PATCH 7/8] Add Published Audio Capabilities Service --- bumble/profiles/bap.py | 496 +++++++++++++++++++++++++++++++++ examples/run_unicast_server.py | 134 +++++++++ tests/bap_test.py | 151 ++++++++++ 3 files changed, 781 insertions(+) create mode 100644 bumble/profiles/bap.py create mode 100644 examples/run_unicast_server.py create mode 100644 tests/bap_test.py diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py new file mode 100644 index 0000000..76015d5 --- /dev/null +++ b/bumble/profiles/bap.py @@ -0,0 +1,496 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +from collections.abc import Sequence +import dataclasses +import enum +import struct +import functools +from typing import Optional, List, Union + +from bumble import hci +from bumble import gatt +from bumble import gatt_client + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + + +class AudioLocation(enum.IntFlag): + '''Bluetooth Assigned Numbers, Section 6.12.1 - Audio Location''' + + # fmt: off + NOT_ALLOWED = 0x00000000 + FRONT_LEFT = 0x00000001 + FRONT_RIGHT = 0x00000002 + FRONT_CENTER = 0x00000004 + LOW_FREQUENCY_EFFECTS_1 = 0x00000008 + BACK_LEFT = 0x00000010 + BACK_RIGHT = 0x00000020 + FRONT_LEFT_OF_CENTER = 0x00000040 + FRONT_RIGHT_OF_CENTER = 0x00000080 + BACK_CENTER = 0x00000100 + LOW_FREQUENCY_EFFECTS_2 = 0x00000200 + SIDE_LEFT = 0x00000400 + SIDE_RIGHT = 0x00000800 + TOP_FRONT_LEFT = 0x00001000 + TOP_FRONT_RIGHT = 0x00002000 + TOP_FRONT_CENTER = 0x00004000 + TOP_CENTER = 0x00008000 + TOP_BACK_LEFT = 0x00010000 + TOP_BACK_RIGHT = 0x00020000 + TOP_SIDE_LEFT = 0x00040000 + TOP_SIDE_RIGHT = 0x00080000 + TOP_BACK_CENTER = 0x00100000 + BOTTOM_FRONT_CENTER = 0x00200000 + BOTTOM_FRONT_LEFT = 0x00400000 + BOTTOM_FRONT_RIGHT = 0x00800000 + FRONT_LEFT_WIDE = 0x01000000 + FRONT_RIGHT_WIDE = 0x02000000 + LEFT_SURROUND = 0x04000000 + RIGHT_SURROUND = 0x08000000 + + +class AudioInputType(enum.IntEnum): + '''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type''' + + # fmt: off + UNSPECIFIED = 0x00 + BLUETOOTH = 0x01 + MICROPHONE = 0x02 + ANALOG = 0x03 + DIGITAL = 0x04 + RADIO = 0x05 + STREAMING = 0x06 + AMBIENT = 0x07 + + +class ContextType(enum.IntFlag): + '''Bluetooth Assigned Numbers, Section 6.12.3 - Context Type''' + + # fmt: off + PROHIBITED = 0x0000 + CONVERSATIONAL = 0x0002 + MEDIA = 0x0004 + GAME = 0x0008 + INSTRUCTIONAL = 0x0010 + VOICE_ASSISTANTS = 0x0020 + LIVE = 0x0040 + SOUND_EFFECTS = 0x0080 + NOTIFICATIONS = 0x0100 + RINGTONE = 0x0200 + ALERTS = 0x0400 + EMERGENCY_ALARM = 0x0800 + + +class SamplingFrequency(enum.IntEnum): + '''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency''' + + # fmt: off + FREQ_8000 = 0x01 + FREQ_11025 = 0x02 + FREQ_16000 = 0x03 + FREQ_22050 = 0x04 + FREQ_24000 = 0x05 + FREQ_32000 = 0x06 + FREQ_44100 = 0x07 + FREQ_48000 = 0x08 + FREQ_88200 = 0x09 + FREQ_96000 = 0x0A + FREQ_176400 = 0x0B + FREQ_192000 = 0x0C + FREQ_384000 = 0x0D + # fmt: on + + @classmethod + def from_hz(cls, frequency: int) -> SamplingFrequency: + return { + 8000: SamplingFrequency.FREQ_8000, + 11025: SamplingFrequency.FREQ_11025, + 16000: SamplingFrequency.FREQ_16000, + 22050: SamplingFrequency.FREQ_22050, + 24000: SamplingFrequency.FREQ_24000, + 32000: SamplingFrequency.FREQ_32000, + 44100: SamplingFrequency.FREQ_44100, + 48000: SamplingFrequency.FREQ_48000, + 88200: SamplingFrequency.FREQ_88200, + 96000: SamplingFrequency.FREQ_96000, + 176400: SamplingFrequency.FREQ_176400, + 192000: SamplingFrequency.FREQ_192000, + 384000: SamplingFrequency.FREQ_384000, + }[frequency] + + @property + def hz(self) -> int: + return { + SamplingFrequency.FREQ_8000: 8000, + SamplingFrequency.FREQ_11025: 11025, + SamplingFrequency.FREQ_16000: 16000, + SamplingFrequency.FREQ_22050: 22050, + SamplingFrequency.FREQ_24000: 24000, + SamplingFrequency.FREQ_32000: 32000, + SamplingFrequency.FREQ_44100: 44100, + SamplingFrequency.FREQ_48000: 48000, + SamplingFrequency.FREQ_88200: 88200, + SamplingFrequency.FREQ_96000: 96000, + SamplingFrequency.FREQ_176400: 176400, + SamplingFrequency.FREQ_192000: 192000, + SamplingFrequency.FREQ_384000: 384000, + }[self] + + +class SupportedSamplingFrequency(enum.IntFlag): + '''Bluetooth Assigned Numbers, Section 6.12.4.1 - Sample Frequency''' + + # fmt: off + FREQ_8000 = 1 << (SamplingFrequency.FREQ_8000 - 1) + FREQ_11025 = 1 << (SamplingFrequency.FREQ_11025 - 1) + FREQ_16000 = 1 << (SamplingFrequency.FREQ_16000 - 1) + FREQ_22050 = 1 << (SamplingFrequency.FREQ_22050 - 1) + FREQ_24000 = 1 << (SamplingFrequency.FREQ_24000 - 1) + FREQ_32000 = 1 << (SamplingFrequency.FREQ_32000 - 1) + FREQ_44100 = 1 << (SamplingFrequency.FREQ_44100 - 1) + FREQ_48000 = 1 << (SamplingFrequency.FREQ_48000 - 1) + FREQ_88200 = 1 << (SamplingFrequency.FREQ_88200 - 1) + FREQ_96000 = 1 << (SamplingFrequency.FREQ_96000 - 1) + FREQ_176400 = 1 << (SamplingFrequency.FREQ_176400 - 1) + FREQ_192000 = 1 << (SamplingFrequency.FREQ_192000 - 1) + FREQ_384000 = 1 << (SamplingFrequency.FREQ_384000 - 1) + # fmt: on + + @classmethod + def from_hz(cls, frequencies: Sequence[int]) -> SupportedSamplingFrequency: + MAPPING = { + 8000: SupportedSamplingFrequency.FREQ_8000, + 11025: SupportedSamplingFrequency.FREQ_11025, + 16000: SupportedSamplingFrequency.FREQ_16000, + 22050: SupportedSamplingFrequency.FREQ_22050, + 24000: SupportedSamplingFrequency.FREQ_24000, + 32000: SupportedSamplingFrequency.FREQ_32000, + 44100: SupportedSamplingFrequency.FREQ_44100, + 48000: SupportedSamplingFrequency.FREQ_48000, + 88200: SupportedSamplingFrequency.FREQ_88200, + 96000: SupportedSamplingFrequency.FREQ_96000, + 176400: SupportedSamplingFrequency.FREQ_176400, + 192000: SupportedSamplingFrequency.FREQ_192000, + 384000: SupportedSamplingFrequency.FREQ_384000, + } + + return functools.reduce( + lambda x, y: x | MAPPING[y], + frequencies, + cls(0), + ) + + +class FrameDuration(enum.IntEnum): + '''Bluetooth Assigned Numbers, Section 6.12.5.2 - Frame Duration''' + + # fmt: off + DURATION_7500_US = 0x00 + DURATION_10000_US = 0x01 + + +class SupportedFrameDuration(enum.IntFlag): + '''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration''' + + # fmt: off + DURATION_7500_US_SUPPORTED = 0b0001 + DURATION_10000_US_SUPPORTED = 0b0010 + DURATION_7500_US_PREFERRED = 0b0001 + DURATION_10000_US_PREFERRED = 0b0010 + + +# ----------------------------------------------------------------------------- +# Utils +# ----------------------------------------------------------------------------- + + +def bits_to_channel_counts(data: int) -> List[int]: + pos = 0 + counts = [] + while data != 0: + # Bit 0 = count 1 + # Bit 1 = count 2, and so on + pos += 1 + if data & 1: + counts.append(pos) + data >>= 1 + return counts + + +def channel_counts_to_bits(counts: Sequence[int]) -> int: + return sum(set([1 << (count - 1) for count in counts])) + + +# ----------------------------------------------------------------------------- +# Structures +# ----------------------------------------------------------------------------- + + +@dataclasses.dataclass +class CodecSpecificCapabilities: + '''See: + * Bluetooth Assigned Numbers, 6.12.4 - Codec Specific Capabilities LTV Structures + * Basic Audio Profile, 4.3.1 - Codec_Specific_Capabilities LTV requirements + ''' + + class Type(enum.IntEnum): + # fmt: off + SAMPLING_FREQUENCY = 0x01 + FRAME_DURATION = 0x02 + AUDIO_CHANNEL_COUNT = 0x03 + OCTETS_PER_FRAME = 0x04 + CODEC_FRAMES_PER_SDU = 0x05 + + supported_sampling_frequencies: SupportedSamplingFrequency + supported_frame_durations: SupportedFrameDuration + supported_audio_channel_counts: Sequence[int] + min_octets_per_codec_frame: int + max_octets_per_codec_frame: int + supported_max_codec_frames_per_sdu: int + + @classmethod + def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: + offset = 0 + # Allowed default values. + supported_audio_channel_counts = [1] + supported_max_codec_frames_per_sdu = 1 + while offset < len(data): + length, type = struct.unpack_from('BB', data, offset) + offset += 2 + value = int.from_bytes(data[offset : offset + length - 1], 'little') + offset += length - 1 + + if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY: + supported_sampling_frequencies = SupportedSamplingFrequency(value) + elif type == CodecSpecificCapabilities.Type.FRAME_DURATION: + supported_frame_durations = SupportedFrameDuration(value) + elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: + supported_audio_channel_counts = bits_to_channel_counts(value) + elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: + min_octets_per_sample = value & 0xFFFF + max_octets_per_sample = value >> 16 + elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU: + supported_max_codec_frames_per_sdu = value + + # It is expected here that if some fields are missing, an error should be raised. + return CodecSpecificCapabilities( + supported_sampling_frequencies=supported_sampling_frequencies, + supported_frame_durations=supported_frame_durations, + supported_audio_channel_counts=supported_audio_channel_counts, + min_octets_per_codec_frame=min_octets_per_sample, + max_octets_per_codec_frame=max_octets_per_sample, + supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu, + ) + + def __bytes__(self) -> bytes: + return struct.pack( + ' PacRecord: + offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0) + codec_specific_capabilities_size = data[offset] + + offset += 1 + codec_specific_capabilities_bytes = data[ + offset : offset + codec_specific_capabilities_size + ] + offset += codec_specific_capabilities_size + metadata_size = data[offset] + metadata = data[offset : offset + metadata_size] + + codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] + if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: + codec_specific_capabilities = codec_specific_capabilities_bytes + else: + codec_specific_capabilities = CodecSpecificCapabilities.from_bytes( + codec_specific_capabilities_bytes + ) + + return PacRecord( + coding_format=coding_format, + codec_specific_capabilities=codec_specific_capabilities, + metadata=metadata, + ) + + def __bytes__(self) -> bytes: + capabilities_bytes = bytes(self.codec_specific_capabilities) + return ( + bytes(self.coding_format) + + bytes([len(capabilities_bytes)]) + + capabilities_bytes + + bytes([len(self.metadata)]) + + self.metadata + ) + + +# ----------------------------------------------------------------------------- +# Server +# ----------------------------------------------------------------------------- +class PublishedAudioCapabilitiesService(gatt.TemplateService): + UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE + + sink_pac: Optional[gatt.Characteristic] + sink_audio_locations: Optional[gatt.Characteristic] + source_pac: Optional[gatt.Characteristic] + source_audio_locations: Optional[gatt.Characteristic] + available_audio_contexts: gatt.Characteristic + supported_audio_contexts: gatt.Characteristic + + def __init__( + self, + supported_source_context: ContextType, + supported_sink_context: ContextType, + available_source_context: ContextType, + available_sink_context: ContextType, + sink_pac: Sequence[PacRecord] = [], + sink_audio_locations: Optional[AudioLocation] = None, + source_pac: Sequence[PacRecord] = [], + source_audio_locations: Optional[AudioLocation] = None, + ) -> None: + characteristics = [] + + self.supported_audio_contexts = gatt.Characteristic( + uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.READ, + permissions=gatt.Characteristic.Permissions.READABLE, + value=struct.pack(' None: + if len(sys.argv) < 3: + print('Usage: run_cig_setup.py ' '') + return + + print('<<< connecting to HCI...') + async with await open_transport_or_link(sys.argv[2]) as hci_transport: + print('<<< connected') + + device = Device.from_config_file_with_hci( + sys.argv[1], hci_transport.source, hci_transport.sink + ) + device.cis_enabled = True + + await device.power_on() + + device.add_service( + PublishedAudioCapabilitiesService( + supported_source_context=ContextType.PROHIBITED, + available_source_context=ContextType.PROHIBITED, + supported_sink_context=ContextType.MEDIA, + available_sink_context=ContextType.MEDIA, + sink_audio_locations=( + AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT + ), + sink_pac=[ + # Codec Capability Setting 16_2 + PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_16000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1], + min_octets_per_codec_frame=40, + max_octets_per_codec_frame=40, + supported_max_codec_frames_per_sdu=1, + ), + ), + # Codec Capability Setting 24_2 + PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_24000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1], + min_octets_per_codec_frame=60, + max_octets_per_codec_frame=60, + supported_max_codec_frames_per_sdu=1, + ), + ), + ], + ) + ) + + advertising_data = bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes('Bumble LE Audio', 'utf-8'), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(PublishedAudioCapabilitiesService.UUID), + ), + ] + ) + ) + + await device.start_extended_advertising( + advertising_properties=( + HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING + ), + own_address_type=OwnAddressType.RANDOM, + advertising_data=advertising_data, + ) + + await hci_transport.source.terminated + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/tests/bap_test.py b/tests/bap_test.py new file mode 100644 index 0000000..01fc568 --- /dev/null +++ b/tests/bap_test.py @@ -0,0 +1,151 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import os +import pytest +import logging + +from bumble import device +from bumble.hci import CodecID, CodingFormat +from bumble.profiles.bap import ( + AudioLocation, + SupportedFrameDuration, + SupportedSamplingFrequency, + CodecSpecificCapabilities, + ContextType, + PacRecord, + PublishedAudioCapabilitiesService, + PublishedAudioCapabilitiesServiceProxy, +) +from .test_utils import TwoDevices + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def test_codec_specific_capabilities() -> None: + SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000 + FRAME_SURATION = SupportedFrameDuration.DURATION_10000_US_SUPPORTED + AUDIO_CHANNEL_COUNTS = [1] + cap = CodecSpecificCapabilities( + supported_sampling_frequencies=SAMPLE_FREQUENCY, + supported_frame_durations=FRAME_SURATION, + supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, + min_octets_per_codec_frame=40, + max_octets_per_codec_frame=40, + supported_max_codec_frames_per_sdu=1, + ) + assert CodecSpecificCapabilities.from_bytes(bytes(cap)) == cap + + +# ----------------------------------------------------------------------------- +def test_pac_record() -> None: + SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000 + FRAME_SURATION = SupportedFrameDuration.DURATION_10000_US_SUPPORTED + AUDIO_CHANNEL_COUNTS = [1] + cap = CodecSpecificCapabilities( + supported_sampling_frequencies=SAMPLE_FREQUENCY, + supported_frame_durations=FRAME_SURATION, + supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, + min_octets_per_codec_frame=40, + max_octets_per_codec_frame=40, + supported_max_codec_frames_per_sdu=1, + ) + + pac_record = PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=cap, + metadata=b'', + ) + assert PacRecord.from_bytes(bytes(pac_record)) == pac_record + + +# ----------------------------------------------------------------------------- +def test_vendor_specific_pac_record() -> None: + # Vendor-Specific codec, Google, ID=0xFFFF. No capabilities and metadata. + RAW_DATA = bytes.fromhex('ffe000ffff0000') + assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_pacs(): + devices = TwoDevices() + devices[0].add_service( + PublishedAudioCapabilitiesService( + supported_sink_context=ContextType.MEDIA, + available_sink_context=ContextType.MEDIA, + supported_source_context=0, + available_source_context=0, + sink_pac=[ + # Codec Capability Setting 16_2 + PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_16000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1], + min_octets_per_codec_frame=40, + max_octets_per_codec_frame=40, + supported_max_codec_frames_per_sdu=1, + ), + ), + # Codec Capability Setting 24_2 + PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_24000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1], + min_octets_per_codec_frame=60, + max_octets_per_codec_frame=60, + supported_max_codec_frames_per_sdu=1, + ), + ), + ], + sink_audio_locations=AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT, + ) + ) + + await devices.setup_connection() + peer = device.Peer(devices.connections[1]) + pacs_client = await peer.discover_service_and_create_proxy( + PublishedAudioCapabilitiesServiceProxy + ) + + +# ----------------------------------------------------------------------------- +async def run(): + await test_pacs() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run()) From dc97be5b3572c03856817f5cc495cbefb3ffb442 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Sat, 2 Dec 2023 23:29:00 +0800 Subject: [PATCH 8/8] Fix typo --- bumble/hci.py | 2 +- tests/hci_test.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/bumble/hci.py b/bumble/hci.py index f652d23..f209249 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1629,7 +1629,7 @@ class HCI_Object: field_bytes = bytes(field_value) elif field_type == 'v': # Variable-length bytes field, with 1-byte length at the beginning - field_bytes = bytes(field_bytes) + field_bytes = bytes(field_value) field_length = len(field_bytes) field_bytes = bytes([field_length]) + field_bytes elif isinstance(field_value, (bytes, bytearray)) or hasattr( diff --git a/tests/hci_test.py b/tests/hci_test.py index b6024ff..5607350 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -53,6 +53,7 @@ from bumble.hci import ( HCI_LE_Set_Random_Address_Command, HCI_LE_Set_Scan_Enable_Command, HCI_LE_Set_Scan_Parameters_Command, + HCI_LE_Setup_ISO_Data_Path_Command, HCI_Number_Of_Completed_Packets_Event, HCI_Packet, HCI_PIN_Code_Request_Reply_Command, @@ -455,6 +456,14 @@ def test_HCI_LE_Setup_ISO_Data_Path_Command(): assert command.controller_delay == 0 assert command.codec_configuration == b'' + command = HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=0x0060, + data_path_direction=0x00, + data_path_id=0x01, + codec_id=CodingFormat(CodecID.TRANSPARENT), + controller_delay=0x00, + codec_configuration=b'', + ) basic_check(command)