From 72ac75a98da7661abce7dbb8fb69e538b3190d7d Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 6 Dec 2023 21:02:30 +0800 Subject: [PATCH 01/16] Add advertiser classes and handle adv set terminated events * Convert hci.OwnAddressType to enum * Add LegacyAdvertiser and ExtendedAdvertiser classes * Rename start/stop_advertising() => start/stop_legacy_advertising() * Handle HCI_Advertising_Set_Terminated * Properly restart advertisement on disconnection --- bumble/device.py | 217 +++++++++++++++++++++++++++++++++++++---------- bumble/hci.py | 43 +++++----- bumble/host.py | 8 ++ 3 files changed, 203 insertions(+), 65 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index 9ae502e4..369040f2 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -437,6 +437,34 @@ class AdvertisingType(IntEnum): ) +# ----------------------------------------------------------------------------- +@dataclass +class LegacyAdvertiser: + advertising_type: AdvertisingType + own_address_type: OwnAddressType + auto_restart: bool + advertising_data: Optional[bytes] + scan_response_data: Optional[bytes] + + +# ----------------------------------------------------------------------------- +@dataclass +class ExtendedAdvertiser(CompositeEventEmitter): + device: Device + handle: int + advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties + own_address_type: OwnAddressType + auto_restart: bool + advertising_data: Optional[bytes] + scan_response_data: Optional[bytes] + + def __post_init__(self) -> None: + super().__init__() + + async def stop(self) -> None: + await self.device.stop_extended_advertising(self.handle) + + # ----------------------------------------------------------------------------- class LePhyOptions: # Coded PHY preference @@ -658,6 +686,9 @@ class Connection(CompositeEventEmitter): gatt_client: gatt_client.Client pairing_peer_io_capability: Optional[int] pairing_peer_authentication_requirements: Optional[int] + advertiser_after_disconnection: Union[ + LegacyAdvertiser, ExtendedAdvertiser, None + ] = None @composite_listener class Listener: @@ -1063,7 +1094,8 @@ class Device(CompositeEventEmitter): ] advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] config: DeviceConfiguration - extended_advertising_handles: Set[int] + legacy_advertiser: Optional[LegacyAdvertiser] + extended_advertisers: Dict[int, ExtendedAdvertiser] sco_links: Dict[int, ScoLink] cis_links: Dict[int, CisLink] _pending_cis: Dict[int, Tuple[int, int]] @@ -1141,10 +1173,7 @@ class Device(CompositeEventEmitter): self._host = None self.powered_on = False - self.advertising = False - self.advertising_type = None self.auto_restart_inquiry = True - self.auto_restart_advertising = False self.command_timeout = 10 # seconds self.gatt_server = gatt_server.Server(self) self.sdp_server = sdp.Server(self) @@ -1168,10 +1197,10 @@ class Device(CompositeEventEmitter): self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY - self.extended_advertising_handles = set() + self.legacy_advertiser = None + self.extended_advertisers = {} # Own address type cache - self.advertising_own_address_type = None self.connect_own_address_type = None # Use the initial config or a default @@ -1579,6 +1608,7 @@ class Device(CompositeEventEmitter): return self.host.supports_le_feature(feature_map[phy]) + @deprecated("Please use start_legacy_advertising.") async def start_advertising( self, advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, @@ -1586,15 +1616,49 @@ class Device(CompositeEventEmitter): own_address_type: int = OwnAddressType.RANDOM, auto_restart: bool = False, ) -> None: + await self.start_legacy_advertising( + advertising_type=advertising_type, + target=target, + own_address_type=OwnAddressType(own_address_type), + auto_restart=auto_restart, + ) + + async def start_legacy_advertising( + self, + advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, + target: Optional[Address] = None, + own_address_type: OwnAddressType = OwnAddressType.RANDOM, + auto_restart: bool = False, + advertising_data: Optional[bytes] = None, + scan_response_data: Optional[bytes] = None, + ) -> LegacyAdvertiser: + """Starts an legacy advertisement. + + Args: + advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command. + target: Directed advertising target. Directed type should be set in advertising_type arg. + own_address_type: own address type to use in the advertising. + auto_restart: whether the advertisement will be restarted after disconnection. + scan_response_data: raw scan response. + advertising_data: raw advertising data. + + Returns: + LegacyAdvertiser object containing the metadata of advertisement. + """ + if self.extended_advertisers: + logger.warning( + 'Trying to start Legacy and Extended Advertising at the same time!' + ) + # If we're advertising, stop first - if self.advertising: + if self.legacy_advertiser: await self.stop_advertising() # Set/update the advertising data if the advertising type allows it if advertising_type.has_data: await self.send_command( HCI_LE_Set_Advertising_Data_Command( - advertising_data=self.advertising_data + advertising_data=advertising_data or self.advertising_data or b'' ), check_result=True, ) @@ -1603,7 +1667,9 @@ class Device(CompositeEventEmitter): if advertising_type.is_scannable: await self.send_command( HCI_LE_Set_Scan_Response_Data_Command( - scan_response_data=self.scan_response_data + scan_response_data=scan_response_data + or self.scan_response_data + or b'' ), check_result=True, ) @@ -1640,45 +1706,56 @@ class Device(CompositeEventEmitter): check_result=True, ) - self.advertising_type = advertising_type - self.advertising_own_address_type = own_address_type - self.advertising = True - self.auto_restart_advertising = auto_restart + self.legacy_advertiser = LegacyAdvertiser( + advertising_type=advertising_type, + own_address_type=own_address_type, + auto_restart=auto_restart, + advertising_data=advertising_data, + scan_response_data=scan_response_data, + ) + return self.legacy_advertiser + @deprecated("Please use stop_legacy_advertising.") async def stop_advertising(self) -> None: + await self.stop_legacy_advertising() + + async def stop_legacy_advertising(self) -> None: # Disable advertising - if self.advertising: + if self.legacy_advertiser: await self.send_command( HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), check_result=True, ) - self.advertising_type = None - self.advertising_own_address_type = None - self.advertising = False - self.auto_restart_advertising = False + self.legacy_advertiser = None @experimental('Extended Advertising is still experimental - Might be changed soon.') async def start_extended_advertising( self, advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING, target: Address = Address.ANY, - own_address_type: int = OwnAddressType.RANDOM, - scan_response: Optional[bytes] = None, + own_address_type: OwnAddressType = OwnAddressType.RANDOM, + auto_restart: bool = True, advertising_data: Optional[bytes] = None, - ) -> int: + scan_response_data: Optional[bytes] = None, + ) -> ExtendedAdvertiser: """Starts an extended advertising set. Args: advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command target: Directed advertising target. Directed property should be set in advertising_properties arg. own_address_type: own address type to use in the advertising. - scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent. + auto_restart: whether the advertisement will be restarted after disconnection. advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent. + scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent. Returns: - Handle of the new advertising set. + ExtendedAdvertiser object containing the metadata of advertisement. """ + if self.legacy_advertiser: + logger.warning( + 'Trying to start Legacy and Extended Advertising at the same time!' + ) adv_handle = -1 # Find a free handle @@ -1686,7 +1763,7 @@ class Device(CompositeEventEmitter): DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE, DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1, ): - if i not in self.extended_advertising_handles: + if i not in self.extended_advertisers.keys(): adv_handle = i break @@ -1733,13 +1810,13 @@ class Device(CompositeEventEmitter): ) # Set the scan response if present - if scan_response is not None: + if scan_response_data is not None: await self.send_command( HCI_LE_Set_Extended_Scan_Response_Data_Command( advertising_handle=adv_handle, operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, fragment_preference=0x01, # Should not fragment - scan_response_data=scan_response, + scan_response_data=scan_response_data, ), check_result=True, ) @@ -1774,8 +1851,16 @@ class Device(CompositeEventEmitter): ) raise error - self.extended_advertising_handles.add(adv_handle) - return adv_handle + advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser( + device=self, + handle=adv_handle, + advertising_properties=advertising_properties, + own_address_type=own_address_type, + auto_restart=auto_restart, + advertising_data=advertising_data, + scan_response_data=scan_response_data, + ) + return advertiser @experimental('Extended Advertising is still experimental - Might be changed soon.') async def stop_extended_advertising(self, adv_handle: int) -> None: @@ -1799,11 +1884,11 @@ class Device(CompositeEventEmitter): HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), check_result=True, ) - self.extended_advertising_handles.remove(adv_handle) + del self.extended_advertisers[adv_handle] @property def is_advertising(self): - return self.advertising + return self.legacy_advertiser or self.extended_advertisers async def start_scanning( self, @@ -3144,13 +3229,17 @@ class Device(CompositeEventEmitter): # Guess which own address type is used for this connection. # This logic is somewhat correct but may need to be improved # when multiple advertising are run simultaneously. + advertiser = None if self.connect_own_address_type is not None: own_address_type = self.connect_own_address_type + elif self.legacy_advertiser: + own_address_type = self.legacy_advertiser.own_address_type + # Store advertiser for restarting - it's only required for legacy, since + # extended advertisement produces HCI_Advertising_Set_Terminated. + advertiser = self.legacy_advertiser else: - own_address_type = self.advertising_own_address_type - - # We are no longer advertising - self.advertising = False + # For extended advertisement, determining own address type later. + own_address_type = OwnAddressType.RANDOM if own_address_type in ( OwnAddressType.PUBLIC, @@ -3172,6 +3261,7 @@ class Device(CompositeEventEmitter): connection_parameters, ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY), ) + connection.advertiser_after_disconnection = advertiser self.connections[connection_handle] = connection # If supported, read which PHY we're connected with before @@ -3203,10 +3293,10 @@ class Device(CompositeEventEmitter): # For directed advertising, this means a timeout if ( transport == BT_LE_TRANSPORT - and self.advertising - and self.advertising_type.is_directed + and self.legacy_advertiser + and self.legacy_advertiser.advertising_type.is_directed ): - self.advertising = False + self.legacy_advertiser = None # Notify listeners error = core.ConnectionError( @@ -3268,16 +3358,28 @@ class Device(CompositeEventEmitter): self.gatt_server.on_disconnection(connection) # Restart advertising if auto-restart is enabled - if self.auto_restart_advertising: + if advertiser := connection.advertiser_after_disconnection: logger.debug('restarting advertising') - self.abort_on( - 'flush', - self.start_advertising( - advertising_type=self.advertising_type, # type: ignore[arg-type] - own_address_type=self.advertising_own_address_type, # type: ignore[arg-type] - auto_restart=True, - ), - ) + if isinstance(advertiser, LegacyAdvertiser): + self.abort_on( + 'flush', + self.start_legacy_advertising( + advertising_type=advertiser.advertising_type, + own_address_type=advertiser.own_address_type, + auto_restart=True, + ), + ) + elif isinstance(advertiser, ExtendedAdvertiser): + self.abort_on( + 'flush', + self.start_extended_advertising( + advertising_properties=advertiser.advertising_properties, + own_address_type=advertiser.own_address_type, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + auto_restart=True, + ), + ) elif sco_link := self.sco_links.pop(connection_handle, None): sco_link.emit('disconnection', reason) elif cis_link := self.cis_links.pop(connection_handle, None): @@ -3600,6 +3702,29 @@ class Device(CompositeEventEmitter): if sco_link := self.sco_links.get(sco_handle, None): sco_link.emit('pdu', packet) + # [LE only] + @host_event_handler + @experimental('Only for testing') + def on_advertising_set_termination( + self, + status: int, + advertising_handle: int, + connection_handle: int, + ) -> None: + if status == HCI_SUCCESS: + connection = self.lookup_connection(connection_handle) + if advertiser := self.extended_advertisers.pop(advertising_handle, None): + if connection: + connection.advertiser_after_disconnection = advertiser + if advertiser.own_address_type in ( + OwnAddressType.PUBLIC, + OwnAddressType.RESOLVABLE_OR_PUBLIC, + ): + connection.self_address = self.public_address + else: + connection.self_address = self.random_address + advertiser.emit('termination', status) + # [LE only] @host_event_handler @with_connection_from_handle diff --git a/bumble/hci.py b/bumble/hci.py index a28246ab..936b6816 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1963,25 +1963,15 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_ # ----------------------------------------------------------------------------- -class OwnAddressType: +class OwnAddressType(enum.IntEnum): PUBLIC = 0 RANDOM = 1 RESOLVABLE_OR_PUBLIC = 2 RESOLVABLE_OR_RANDOM = 3 - TYPE_NAMES = { - PUBLIC: 'PUBLIC', - RANDOM: 'RANDOM', - RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC', - RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM', - } - - @staticmethod - def type_name(type_id): - return name_or_number(OwnAddressType.TYPE_NAMES, type_id) - - # pylint: disable-next=unnecessary-lambda - TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)} + @classmethod + def type_spec(cls): + return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name} # ----------------------------------------------------------------------------- @@ -3374,7 +3364,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command): ), }, ), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), ('advertising_channel_map', 1), @@ -3467,7 +3457,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command): ('le_scan_type', 1), ('le_scan_interval', 2), ('le_scan_window', 2), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('scanning_filter_policy', 1), ] ) @@ -3506,7 +3496,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command): ('initiator_filter_policy', 1), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('connection_interval_min', 2), ('connection_interval_max', 2), ('max_latency', 2), @@ -3913,7 +3903,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command): ), }, ), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), ('advertising_filter_policy', 1), @@ -4309,7 +4299,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command): ('initiator_filter_policy:', self.initiator_filter_policy), ( 'own_address_type: ', - OwnAddressType.type_name(self.own_address_type), + OwnAddressType(self.own_address_type).name, ), ( 'peer_address_type: ', @@ -5190,6 +5180,21 @@ HCI_LE_Meta_Event.subevent_classes[ ] = HCI_LE_Extended_Advertising_Report_Event +# ----------------------------------------------------------------------------- +@HCI_LE_Meta_Event.event( + [ + ('status', 1), + ('advertising_handle', 1), + ('connection_handle', 2), + ('number_completed_extended_advertising_events', 1), + ] +) +class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event + ''' + + # ----------------------------------------------------------------------------- @HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)]) class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event): diff --git a/bumble/host.py b/bumble/host.py index b06ceba4..3ae2280b 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -721,6 +721,14 @@ class Host(AbortableEventEmitter): def on_hci_le_extended_advertising_report_event(self, event): self.on_hci_le_advertising_report_event(event) + def on_hci_le_advertising_set_terminated_event(self, event): + self.emit( + 'advertising_set_termination', + event.status, + event.advertising_handle, + event.connection_handle, + ) + def on_hci_le_cis_request_event(self, event): self.emit( 'cis_request', From ff6528d2bf38fe39aec683be33cfab1e2c263771 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 8 Dec 2023 00:03:21 +0800 Subject: [PATCH 02/16] Add Advertising unit tests --- bumble/device.py | 15 +++- tests/device_test.py | 175 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 186 insertions(+), 4 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index 369040f2..f0f4ee18 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -440,12 +440,16 @@ class AdvertisingType(IntEnum): # ----------------------------------------------------------------------------- @dataclass class LegacyAdvertiser: + device: Device advertising_type: AdvertisingType own_address_type: OwnAddressType auto_restart: bool advertising_data: Optional[bytes] scan_response_data: Optional[bytes] + async def stop(self) -> None: + await self.device.stop_legacy_advertising() + # ----------------------------------------------------------------------------- @dataclass @@ -1707,6 +1711,7 @@ class Device(CompositeEventEmitter): ) self.legacy_advertiser = LegacyAdvertiser( + device=self, advertising_type=advertising_type, own_address_type=own_address_type, auto_restart=auto_restart, @@ -1763,7 +1768,7 @@ class Device(CompositeEventEmitter): DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE, DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1, ): - if i not in self.extended_advertisers.keys(): + if i not in self.extended_advertisers: adv_handle = i break @@ -3236,7 +3241,8 @@ class Device(CompositeEventEmitter): own_address_type = self.legacy_advertiser.own_address_type # Store advertiser for restarting - it's only required for legacy, since # extended advertisement produces HCI_Advertising_Set_Terminated. - advertiser = self.legacy_advertiser + if self.legacy_advertiser.auto_restart: + advertiser = self.legacy_advertiser else: # For extended advertisement, determining own address type later. own_address_type = OwnAddressType.RANDOM @@ -3366,6 +3372,8 @@ class Device(CompositeEventEmitter): self.start_legacy_advertising( advertising_type=advertiser.advertising_type, own_address_type=advertiser.own_address_type, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, auto_restart=True, ), ) @@ -3715,7 +3723,8 @@ class Device(CompositeEventEmitter): connection = self.lookup_connection(connection_handle) if advertiser := self.extended_advertisers.pop(advertising_handle, None): if connection: - connection.advertiser_after_disconnection = advertiser + if advertiser.auto_restart: + connection.advertiser_after_disconnection = advertiser if advertiser.own_address_type in ( OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC, diff --git a/tests/device_test.py b/tests/device_test.py index 1bcd0d08..d51431f5 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -20,8 +20,14 @@ import logging import os from types import LambdaType import pytest +from unittest import mock -from bumble.core import BT_BR_EDR_TRANSPORT +from bumble.core import ( + BT_BR_EDR_TRANSPORT, + BT_LE_TRANSPORT, + BT_PERIPHERAL_ROLE, + ConnectionParameters, +) from bumble.device import Connection, Device from bumble.host import Host from bumble.hci import ( @@ -30,6 +36,7 @@ from bumble.hci import ( HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS, Address, + OwnAddressType, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, @@ -232,6 +239,172 @@ async def test_flush(): pass +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_legacy_advertising(): + device = Device(host=mock.AsyncMock(Host)) + + # Start advertising + advertiser = await device.start_legacy_advertising() + assert device.legacy_advertiser + + # Stop advertising + await advertiser.stop() + assert not device.legacy_advertiser + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'own_address_type,', + (OwnAddressType.PUBLIC, OwnAddressType.RANDOM), +) +@pytest.mark.asyncio +async def test_legacy_advertising_connection(own_address_type): + device = Device(host=mock.AsyncMock(Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + + # Start advertising + advertiser = await device.start_legacy_advertising() + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + + if own_address_type == OwnAddressType.PUBLIC: + assert device.lookup_connection(0x0001).self_address == device.public_address + else: + assert device.lookup_connection(0x0001).self_address == device.random_address + + # For unknown reason, read_phy() in on_connection() would be killed at the end of + # test, so we force scheduling here to avoid an warning. + await asyncio.sleep(0.0001) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'auto_restart,', + (True, False), +) +@pytest.mark.asyncio +async def test_legacy_advertising_disconnection(auto_restart): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_legacy_advertising(auto_restart=auto_restart) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + + device.start_legacy_advertising = mock.AsyncMock() + + device.on_disconnection(0x0001, 0) + + if auto_restart: + device.start_legacy_advertising.assert_called_with( + advertising_type=advertiser.advertising_type, + own_address_type=advertiser.own_address_type, + auto_restart=advertiser.auto_restart, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + ) + else: + device.start_legacy_advertising.assert_not_called() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_extended_advertising(): + device = Device(host=mock.AsyncMock(Host)) + + # Start advertising + advertiser = await device.start_extended_advertising() + assert device.extended_advertisers + + # Stop advertising + await advertiser.stop() + assert not device.extended_advertisers + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'own_address_type,', + (OwnAddressType.PUBLIC, OwnAddressType.RANDOM), +) +@pytest.mark.asyncio +async def test_extended_advertising_connection(own_address_type): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_extended_advertising( + own_address_type=own_address_type + ) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + device.on_advertising_set_termination( + HCI_SUCCESS, + advertiser.handle, + 0x0001, + ) + + if own_address_type == OwnAddressType.PUBLIC: + assert device.lookup_connection(0x0001).self_address == device.public_address + else: + assert device.lookup_connection(0x0001).self_address == device.random_address + + # For unknown reason, read_phy() in on_connection() would be killed at the end of + # test, so we force scheduling here to avoid an warning. + await asyncio.sleep(0.0001) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'auto_restart,', + (True, False), +) +@pytest.mark.asyncio +async def test_extended_advertising_disconnection(auto_restart): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_extended_advertising(auto_restart=auto_restart) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + device.on_advertising_set_termination( + HCI_SUCCESS, + advertiser.handle, + 0x0001, + ) + + device.start_extended_advertising = mock.AsyncMock() + + device.on_disconnection(0x0001, 0) + + if auto_restart: + device.start_extended_advertising.assert_called_with( + advertising_properties=advertiser.advertising_properties, + own_address_type=advertiser.own_address_type, + auto_restart=advertiser.auto_restart, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + ) + else: + device.start_extended_advertising.assert_not_called() + + # ----------------------------------------------------------------------------- def test_gatt_services_with_gas(): device = Device(host=Host(None, None)) From 3575f9030ee54c27541726b5135198f5417c1487 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 30 Nov 2023 14:31:58 +0800 Subject: [PATCH 03/16] Add Audio Stream Control Service --- bumble/gatt_server.py | 2 +- bumble/hci.py | 13 + bumble/profiles/bap.py | 604 ++++++++++++++++++++++++++++++++- examples/leaudio.json | 1 + examples/run_unicast_server.py | 7 + 5 files changed, 625 insertions(+), 2 deletions(-) diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index cdf1b5e5..eca11ce4 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -961,7 +961,7 @@ class Server(EventEmitter): try: attribute.write_value(connection, request.attribute_value) except Exception as error: - logger.warning(f'!!! ignoring exception: {error}') + logger.exception(f'!!! ignoring exception: {error}') def on_att_handle_value_confirmation(self, connection, _confirmation): ''' diff --git a/bumble/hci.py b/bumble/hci.py index 936b6816..8d5f9cd9 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -728,6 +728,19 @@ HCI_LE_PHY_TYPE_TO_BIT = { HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT } + +class Phy(enum.IntEnum): + LE_1M = 0x01 + LE_2M = 0x02 + LE_CODED = 0x03 + + +class PhyBit(enum.IntFlag): + LE_1M = 0b00000001 + LE_2M = 0b00000010 + LE_CODED = 0b00000100 + + # Connection Parameters HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25 HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25 diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 76015d52..a1cae1b7 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -23,8 +23,11 @@ import dataclasses import enum import struct import functools -from typing import Optional, List, Union +import logging +from typing import Optional, List, Union, Type, Dict, Any, Tuple, cast +from bumble import colors +from bumble import device from bumble import hci from bumble import gatt from bumble import gatt_client @@ -220,6 +223,231 @@ class SupportedFrameDuration(enum.IntFlag): DURATION_10000_US_PREFERRED = 0b0010 +# ----------------------------------------------------------------------------- +# ASE Operations +# ----------------------------------------------------------------------------- + + +class ASE_Operation: + ''' + See Audio Stream Control Service - 5 ASE Control operations. + ''' + + classes: Dict[int, Type[ASE_Operation]] = {} + op_code: int + name: str + fields: Optional[Sequence[Any]] = None + ase_id: List[int] + + class Opcode(enum.IntEnum): + # fmt: off + CONFIG_CODEC = 0x01 + CONFIG_QOS = 0x02 + ENABLE = 0x03 + RECEIVER_START_READY = 0x04 + DISABLE = 0x05 + RECEIVER_STOP_READY = 0x06 + UPDATE_METADATA = 0x07 + RELEASE = 0x08 + + @staticmethod + def from_bytes(pdu: bytes) -> ASE_Operation: + op_code = pdu[0] + + cls = ASE_Operation.classes.get(op_code) + if cls is None: + instance = ASE_Operation(pdu) + instance.name = ASE_Operation.Opcode(op_code).name + instance.op_code = op_code + return instance + self = cls.__new__(cls) + ASE_Operation.__init__(self, pdu) + if self.fields is not None: + self.init_from_bytes(pdu, 1) + return self + + @staticmethod + def subclass(fields): + def inner(cls: Type[ASE_Operation]): + try: + operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] + cls.name = operation.name + cls.op_code = operation + except: + raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') + cls.fields = fields + + # Register a factory for this class + ASE_Operation.classes[cls.op_code] = cls + + return cls + + return inner + + def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: + if self.fields is not None and kwargs: + hci.HCI_Object.init_from_fields(self, self.fields, kwargs) + if pdu is None: + pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( + kwargs, self.fields + ) + self.pdu = pdu + + def init_from_bytes(self, pdu: bytes, offset: int): + return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) + + def __bytes__(self) -> bytes: + return self.pdu + + def __str__(self) -> str: + result = f'{colors.color(self.name, "yellow")} ' + if fields := getattr(self, 'fields', None): + result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') + else: + if len(self.pdu) > 1: + result += f': {self.pdu.hex()}' + return result + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('target_latency', 1), + ('target_phy', 1), + ('codec_id', hci.CodingFormat.parse_from_bytes), + ('codec_specific_configuration', 'v'), + ], + ] +) +class ASE_Config_Codec(ASE_Operation): + ''' + See Audio Stream Control Service 5.1 - Config Codec Operation + ''' + + target_latency: List[int] + target_phy: List[int] + codec_id: List[hci.CodingFormat] + codec_specific_configuration: List[bytes] + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('cig_id', 1), + ('cis_id', 1), + ('sdu_interval', 3), + ('framing', 1), + ('phy', 1), + ('max_sdu', 2), + ('retransmission_number', 1), + ('max_transport_latency', 2), + ('presentation_delay', 3), + ], + ] +) +class ASE_Config_QOS(ASE_Operation): + ''' + See Audio Stream Control Service 5.2 - Config Qos Operation + ''' + + cig_id: List[int] + cis_id: List[int] + sdu_interval: List[int] + framing: List[int] + phy: List[int] + max_sdu: List[int] + retransmission_number: List[int] + max_transport_latency: List[int] + presentation_delay: List[int] + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Enable(ASE_Operation): + ''' + See Audio Stream Control Service 5.3 - Enable Operation + ''' + + metadata: bytes + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Start_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.4 - Receiver Start Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Disable(ASE_Operation): + ''' + See Audio Stream Control Service 5.5 - Disable Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Stop_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Update_Metadata(ASE_Operation): + ''' + See Audio Stream Control Service 5.7 - Update Metadata Operation + ''' + + metadata: List[bytes] + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Release(ASE_Operation): + ''' + See Audio Stream Control Service 5.8 - Release Operation + ''' + + +class AseResponseCode(enum.IntEnum): + # fmt: off + SUCCESS = 0x00 + UNSUPPORTED_OPCODE = 0x01 + INVALID_LENGTH = 0x02 + INVALID_ASE_ID = 0x03 + INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 + INVALID_ASE_DIRECTION = 0x05 + UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 + UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 + REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 + INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 + UNSUPPORTED_METADATA = 0x0A + REJECTED_METADATA = 0x0B + INVALID_METADATA = 0x0C + INSUFFICIENT_RESOURCES = 0x0D + UNSPECIFIED_ERROR = 0x0E + + +class AseReasonCode(enum.IntEnum): + # fmt: off + NONE = 0x00 + CODEC_ID = 0x01 + CODEC_SPECIFIC_CONFIGURATION = 0x02 + SDU_INTERVAL = 0x03 + FRAMING = 0x04 + PHY = 0x05 + MAXIMUM_SDU_SIZE = 0x06 + RETRANSMISSION_NUMBER = 0x07 + MAX_TRANSPORT_LATENCY = 0x08 + PRESENTATION_DELAY = 0x09 + INVALID_ASE_CIS_MAPPING = 0x0A + + +class AudioRole(enum.Enum): + SINK = enum.auto() + SOURCE = enum.auto() + + # ----------------------------------------------------------------------------- # Utils # ----------------------------------------------------------------------------- @@ -452,6 +680,380 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService): super().__init__(characteristics) +class AseStateMachine(gatt.Characteristic): + class State(enum.IntEnum): + # fmt: off + IDLE = 0x00 + CODEC_CONFIGURED = 0x01 + QOS_CONFIGURED = 0x02 + ENABLING = 0x03 + STREAMING = 0x04 + DISABLING = 0x05 + RELEASING = 0x06 + + # Additional parameters in CODEC_CONFIGURED State + preferred_framing = 0 # Unframed PDU supported + preferred_phy = 0 + preferred_retransmission_number = 13 + preferred_max_transport_latency = 100 + supported_presentation_delay_min = 0 + supported_presentation_delay_max = 0 + preferred_presentation_delay_min = 0 + preferred_presentation_delay_max = 0 + codec_id = hci.CodingFormat(hci.CodecID.LC3) + # TODO: Parse this + codec_specific_configuration = b'' + + # Additional parameters in QOS_CONFIGURED State + cig_id = 0 + cis_id = 0 + sdu_interval = 0 + framing = 0 + phy = 0 + max_sdu = 0 + retransmission_number = 0 + max_transport_latency = 0 + presentation_delay = 0 + + # Additional parameters in ENABLING, STREAMING, DISABLING State + # TODO: Parse this + metadata = b'' + + def __init__( + self, + role: AudioRole, + ase_id: int, + service: AudioStreamControlService, + ) -> None: + self.service = service + self.ase_id = ase_id + self.state = AseStateMachine.State.IDLE + self.role = role + + uuid = ( + gatt.GATT_SINK_ASE_CHARACTERISTIC + if role == AudioRole.SINK + else gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + super().__init__( + uuid=uuid, + properties=gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.READABLE, + value=gatt.CharacteristicValue(read=self.on_read), + ) + + self.service.device.on('cis_request', self.on_cis_request) + self.service.device.on('cis_establishment', self.on_cis_establishment) + + def on_cis_request( + self, + acl_connection: device.Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ) -> None: + if cis_id == self.cis_id and self.state == self.State.ENABLING: + acl_connection.abort_on( + 'flush', self.service.device.accept_cis_request(cis_handle) + ) + + def on_cis_establishment(self, cis_link: device.CisLink) -> None: + if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: + self.state = self.State.STREAMING + cis_link.acl_connection.abort_on( + 'flush', self.service.device.notify_subscribers(self, self.value) + ) + + def on_config_codec( + self, + target_latency: int, + target_phy: int, + codec_id: hci.CodingFormat, + codec_specific_configuration: bytes, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + self.State.IDLE, + self.State.CODEC_CONFIGURED, + self.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.max_transport_latency = target_latency + self.phy = target_phy + self.codec_id = codec_id + self.codec_specific_configuration = codec_specific_configuration + + self.state = self.State.CODEC_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_config_qos( + self, + cig_id: int, + cis_id: int, + sdu_interval: int, + framing: int, + phy: int, + max_sdu: int, + retransmission_number: int, + max_transport_latency: int, + presentation_delay: int, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.CODEC_CONFIGURED, + AseStateMachine.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.cig_id = cig_id + self.cis_id = cis_id + self.sdu_interval = sdu_interval + self.framing = framing + self.phy = phy + self.max_sdu = max_sdu + self.retransmission_number = retransmission_number + self.max_transport_latency = max_transport_latency + self.presentation_delay = presentation_delay + + self.state = self.State.QOS_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.QOS_CONFIGURED: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.metadata = metadata + self.state = self.State.ENABLING + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.ENABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.STREAMING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.DISABLING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.DISABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.QOS_CONFIGURED + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_update_metadata( + self, metadata: bytes + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.metadata = metadata + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.IDLE: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.RELEASING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + @property + def value(self): + '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' + + if self.state == self.State.CODEC_CONFIGURED: + additional_parameters = ( + struct.pack( + ' bytes: + return self.value + + +class AudioStreamControlService(gatt.TemplateService): + UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE + + ase_state_machines: Dict[int, AseStateMachine] + ase_control_point: gatt.Characteristic + + def __init__( + self, + device: device.Device, + source_ase_id: Sequence[int] = [], + sink_ase_id: Sequence[int] = [], + ) -> None: + self.device = device + self.ase_state_machines = { + id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) + for id in sink_ase_id + } | { + id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) + for id in source_ase_id + } # ASE state machines, by ASE ID + + for ase in self.ase_state_machines.values(): + print(ase.ase_id) + + self.ase_control_point = gatt.Characteristic( + uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.WRITE + | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.WRITEABLE, + value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), + ) + + super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) + + def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): + if ase := self.ase_state_machines.get(ase_id): + handler = getattr(ase, 'on_' + opcode.name.lower()) + return (ase_id, *handler(*args)) + else: + return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + + def on_write_ase_control_point(self, connection, data): + operation = ASE_Operation.from_bytes(data) + responses = [] + logging.debug(f'*** ASCS Write {operation} ***') + + if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: + for ase_id, *args in zip( + operation.ase_id, + operation.target_latency, + operation.target_phy, + operation.codec_id, + operation.codec_specific_configuration, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: + for ase_id, *args in zip( + operation.ase_id, + operation.cig_id, + operation.cis_id, + operation.sdu_interval, + operation.framing, + operation.phy, + operation.max_sdu, + operation.retransmission_number, + operation.max_transport_latency, + operation.presentation_delay, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.ENABLE, + ASE_Operation.Opcode.UPDATE_METADATA, + ): + for ase_id, *args in zip( + operation.ase_id, + operation.metadata, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.RECEIVER_START_READY, + ASE_Operation.Opcode.DISABLE, + ASE_Operation.Opcode.RECEIVER_STOP_READY, + ASE_Operation.Opcode.RELEASE, + ): + for ase_id in operation.ase_id: + responses.append(self.on_operation(operation.op_code, ase_id, [])) + + control_point_notification = bytes( + [operation.op_code, len(responses)] + ) + b''.join(map(bytes, responses)) + self.device.abort_on( + 'flush', + self.device.notify_subscribers( + self.ase_control_point, control_point_notification + ), + ) + + for ase_id, *_ in responses: + if ase := self.ase_state_machines.get(ase_id): + self.device.abort_on( + 'flush', + self.device.notify_subscribers(ase, ase.value), + ) + + # ----------------------------------------------------------------------------- # Client # ----------------------------------------------------------------------------- diff --git a/examples/leaudio.json b/examples/leaudio.json index 4b6edfce..c4c5a11c 100644 --- a/examples/leaudio.json +++ b/examples/leaudio.json @@ -1,5 +1,6 @@ { "name": "Bumble-LEA", "keystore": "JsonKeyStore", + "address": "F0:F1:F2:F3:F4:FA", "advertising_interval": 100 } diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 868b4f84..4dadec85 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -35,6 +35,7 @@ from bumble.profiles.bap import ( SupportedFrameDuration, PacRecord, PublishedAudioCapabilitiesService, + AudioStreamControlService, ) from bumble.transport import open_transport_or_link @@ -103,6 +104,8 @@ async def main() -> None: ) ) + device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) + advertising_data = bytes( AdvertisingData( [ @@ -110,6 +113,10 @@ async def main() -> None: AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble LE Audio', 'utf-8'), ), + ( + AdvertisingData.FLAGS, + bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]), + ), ( AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(PublishedAudioCapabilitiesService.UUID), From af5776222767dcbf28eb5a1043c62f3babf9150d Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Sat, 2 Dec 2023 18:52:48 +0800 Subject: [PATCH 04/16] Parse CodecSpecificConfiguration --- bumble/profiles/bap.py | 91 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 5 deletions(-) diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index a1cae1b7..ef94d547 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -553,6 +553,80 @@ class CodecSpecificCapabilities: ) +@dataclasses.dataclass +class CodecSpecificConfiguration: + '''See: + * Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures + * Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements + ''' + + class Type(enum.IntEnum): + # fmt: off + SAMPLING_FREQUENCY = 0x01 + FRAME_DURATION = 0x02 + AUDIO_CHANNEL_ALLOCATION = 0x03 + OCTETS_PER_FRAME = 0x04 + CODEC_FRAMES_PER_SDU = 0x05 + + sampling_frequency: SamplingFrequency + frame_duration: FrameDuration + audio_channel_allocation: AudioLocation + octets_per_codec_frame: int + codec_frames_per_sdu: int + + @classmethod + def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration: + offset = 0 + # Allowed default values. + audio_channel_allocation = AudioLocation.NOT_ALLOWED + codec_frames_per_sdu = 1 + while offset < len(data): + length, type = struct.unpack_from('BB', data, offset) + offset += 2 + value = int.from_bytes(data[offset : offset + length - 1], 'little') + offset += length - 1 + + if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY: + sampling_frequency = SamplingFrequency(value) + elif type == CodecSpecificConfiguration.Type.FRAME_DURATION: + frame_duration = FrameDuration(value) + elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION: + audio_channel_allocation = AudioLocation(value) + elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME: + octets_per_codec_frame = value + elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU: + codec_frames_per_sdu = value + + # It is expected here that if some fields are missing, an error should be raised. + return CodecSpecificConfiguration( + sampling_frequency=sampling_frequency, + frame_duration=frame_duration, + audio_channel_allocation=audio_channel_allocation, + octets_per_codec_frame=octets_per_codec_frame, + codec_frames_per_sdu=codec_frames_per_sdu, + ) + + def __bytes__(self) -> bytes: + return struct.pack( + ' Date: Sat, 2 Dec 2023 19:00:45 +0800 Subject: [PATCH 05/16] Setup data path after CIS established --- bumble/hci.py | 4 ++++ bumble/profiles/bap.py | 24 ++++++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/bumble/hci.py b/bumble/hci.py index 8d5f9cd9..36c049c9 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4554,6 +4554,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command): See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command ''' + class Direction(enum.IntEnum): + HOST_TO_CONTROLLER = 0x00 + CONTROLLER_TO_HOST = 0x01 + connection_handle: int data_path_direction: int data_path_id: int diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index ef94d547..aad299d8 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -443,9 +443,9 @@ class AseReasonCode(enum.IntEnum): INVALID_ASE_CIS_MAPPING = 0x0A -class AudioRole(enum.Enum): - SINK = enum.auto() - SOURCE = enum.auto() +class AudioRole(enum.IntEnum): + SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST + SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER # ----------------------------------------------------------------------------- @@ -834,9 +834,21 @@ class AseStateMachine(gatt.Characteristic): def on_cis_establishment(self, cis_link: device.CisLink) -> None: if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: self.state = self.State.STREAMING - cis_link.acl_connection.abort_on( - 'flush', self.service.device.notify_subscribers(self, self.value) - ) + + async def post_cis_established(): + await self.service.device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=cis_link.handle, + data_path_direction=self.role, + data_path_id=0x00, # Fixed HCI + codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ) + ) + await self.service.device.notify_subscribers(self, self.value) + + cis_link.acl_connection.abort_on('flush', post_cis_established()) def on_config_codec( self, From 4d6822d31258d0af7386a83b11a8e06976915986 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 6 Dec 2023 16:33:55 +0800 Subject: [PATCH 06/16] Remove ISO data path on release --- bumble/profiles/bap.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index aad299d8..bc06dae0 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -765,6 +765,8 @@ class AseStateMachine(gatt.Characteristic): DISABLING = 0x05 RELEASING = 0x06 + cis_link: Optional[device.CisLink] = None + # Additional parameters in CODEC_CONFIGURED State preferred_framing = 0 # Unframed PDU supported preferred_phy = 0 @@ -834,6 +836,7 @@ class AseStateMachine(gatt.Characteristic): def on_cis_establishment(self, cis_link: device.CisLink) -> None: if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: self.state = self.State.STREAMING + self.cis_link = cis_link async def post_cis_established(): await self.service.device.send_command( @@ -979,6 +982,18 @@ class AseStateMachine(gatt.Characteristic): AseReasonCode.NONE, ) self.state = self.State.RELEASING + + async def remove_cis_async(): + await self.service.device.send_command( + hci.HCI_LE_Remove_ISO_Data_Path_Command( + connection_handle=self.cis_link.handle, + data_path_direction=self.role, + ) + ) + self.state = self.State.CODEC_CONFIGURED + await self.service.device.notify_subscribers(self, self.value) + + self.service.device.abort_on('flush', remove_cis_async()) return (AseResponseCode.SUCCESS, AseReasonCode.NONE) @property From 55596176c2fef1fbc3b0559c35580ca99c072cea Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 6 Dec 2023 16:25:41 +0800 Subject: [PATCH 07/16] ffplay routing --- examples/run_unicast_server.py | 47 ++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 4dadec85..e71cbeff 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -19,12 +19,14 @@ import asyncio import logging import sys import os +import struct from bumble.core import AdvertisingData -from bumble.device import Device +from bumble.device import Device, CisLink from bumble.hci import ( CodecID, CodingFormat, OwnAddressType, + HCI_IsoDataPacket, HCI_LE_Set_Extended_Advertising_Parameters_Command, ) from bumble.profiles.bap import ( @@ -115,7 +117,13 @@ async def main() -> None: ), ( AdvertisingData.FLAGS, - bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]), + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), ), ( AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, @@ -124,6 +132,41 @@ async def main() -> None: ] ) ) + subprocess = await asyncio.create_subprocess_shell( + f'dlc3 | ffplay pipe:0', + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdin = subprocess.stdin + assert stdin + + # Write a fake LC3 header to dlc3. + stdin.write( + bytes([0x1C, 0xCC]) # Header. + + struct.pack( + ' Date: Thu, 7 Dec 2023 19:25:50 +0800 Subject: [PATCH 08/16] Fix ASE state change --- bumble/profiles/bap.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index bc06dae0..d06dac54 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -33,6 +33,11 @@ from bumble import gatt from bumble import gatt_client +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -976,7 +981,7 @@ class AseStateMachine(gatt.Characteristic): return (AseResponseCode.SUCCESS, AseReasonCode.NONE) def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.IDLE: + if self.state == AseStateMachine.State.IDLE: return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseReasonCode.NONE, @@ -990,7 +995,7 @@ class AseStateMachine(gatt.Characteristic): data_path_direction=self.role, ) ) - self.state = self.State.CODEC_CONFIGURED + self.state = self.State.IDLE await self.service.device.notify_subscribers(self, self.value) self.service.device.abort_on('flush', remove_cis_async()) @@ -1101,7 +1106,7 @@ class AudioStreamControlService(gatt.TemplateService): def on_write_ase_control_point(self, connection, data): operation = ASE_Operation.from_bytes(data) responses = [] - logging.debug(f'*** ASCS Write {operation} ***') + logger.debug(f'*** ASCS Write {operation} ***') if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: for ase_id, *args in zip( From dd090c9e6be690d2c00402e1c927d374181d99fb Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 8 Dec 2023 11:00:44 +0800 Subject: [PATCH 09/16] Add ASCS tests --- bumble/controller.py | 12 ++ bumble/profiles/bap.py | 21 ++++ tests/bap_test.py | 251 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 284 insertions(+) diff --git a/bumble/controller.py b/bumble/controller.py index d035bcca..4ead098e 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -1263,3 +1263,15 @@ class Controller: See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command ''' return struct.pack(' None: SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000 @@ -85,6 +108,92 @@ def test_vendor_specific_pac_record() -> None: assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA +# ----------------------------------------------------------------------------- +def test_ASE_Config_Codec() -> None: + operation = ASE_Config_Codec( + ase_id=[1, 2], + target_latency=[3, 4], + target_phy=[5, 6], + codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)], + codec_specific_configuration=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Config_QOS() -> None: + operation = ASE_Config_QOS( + ase_id=[1, 2], + cig_id=[1, 2], + cis_id=[3, 4], + sdu_interval=[5, 6], + framing=[0, 1], + phy=[2, 3], + max_sdu=[4, 5], + retransmission_number=[6, 7], + max_transport_latency=[8, 9], + presentation_delay=[10, 11], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Enable() -> None: + operation = ASE_Enable( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Update_Metadata() -> None: + operation = ASE_Update_Metadata( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Disable() -> None: + operation = ASE_Disable(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Release() -> None: + operation = ASE_Release(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Receiver_Start_Ready() -> None: + operation = ASE_Receiver_Start_Ready(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Receiver_Stop_Ready() -> None: + operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_codec_specific_configuration() -> None: + SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000 + FRAME_SURATION = FrameDuration.DURATION_10000_US + AUDIO_LOCATION = AudioLocation.FRONT_LEFT + config = CodecSpecificConfiguration( + sampling_frequency=SAMPLE_FREQUENCY, + frame_duration=FRAME_SURATION, + audio_channel_allocation=AUDIO_LOCATION, + octets_per_codec_frame=60, + codec_frames_per_sdu=1, + ) + assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_pacs(): @@ -140,6 +249,148 @@ async def test_pacs(): ) +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_ascs(): + devices = TwoDevices() + devices[0].add_service( + AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2]) + ) + + await devices.setup_connection() + peer = device.Peer(devices.connections[1]) + ascs_client = await peer.discover_service_and_create_proxy( + AudioStreamControlServiceProxy + ) + + notifications = {1: asyncio.Queue(), 2: asyncio.Queue()} + + def on_notification(data: bytes, ase_id: int): + notifications[ase_id].put_nowait(data) + + # Should be idle + assert await ascs_client.sink_ase[0].read_value() == bytes( + [1, AseStateMachine.State.IDLE] + ) + assert await ascs_client.sink_ase[1].read_value() == bytes( + [2, AseStateMachine.State.IDLE] + ) + + # Subscribe + await ascs_client.sink_ase[0].subscribe( + functools.partial(on_notification, ase_id=1) + ) + await ascs_client.sink_ase[1].subscribe( + functools.partial(on_notification, ase_id=2) + ) + + # Config Codec + config = CodecSpecificConfiguration( + sampling_frequency=SamplingFrequency.FREQ_48000, + frame_duration=FrameDuration.DURATION_10000_US, + audio_channel_allocation=AudioLocation.FRONT_LEFT, + octets_per_codec_frame=120, + codec_frames_per_sdu=1, + ) + await ascs_client.ase_control_point.write_value( + ASE_Config_Codec( + ase_id=[1, 2], + target_latency=[3, 4], + target_phy=[5, 6], + codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)], + codec_specific_configuration=[config, config], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.CODEC_CONFIGURED] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.CODEC_CONFIGURED] + ) + + # Config QOS + await ascs_client.ase_control_point.write_value( + ASE_Config_QOS( + ase_id=[1, 2], + cig_id=[1, 2], + cis_id=[3, 4], + sdu_interval=[5, 6], + framing=[0, 1], + phy=[2, 3], + max_sdu=[4, 5], + retransmission_number=[6, 7], + max_transport_latency=[8, 9], + presentation_delay=[10, 11], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.QOS_CONFIGURED] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.QOS_CONFIGURED] + ) + + # Enable + await ascs_client.ase_control_point.write_value( + ASE_Enable( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.ENABLING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.ENABLING] + ) + + # CIS establishment + devices[0].emit( + 'cis_establishment', + device.CisLink( + device=devices[0], + acl_connection=devices.connections[0], + handle=5, + cis_id=3, + cig_id=1, + ), + ) + devices[0].emit( + 'cis_establishment', + device.CisLink( + device=devices[0], + acl_connection=devices.connections[0], + handle=6, + cis_id=4, + cig_id=2, + ), + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.STREAMING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.STREAMING] + ) + + # Release + await ascs_client.ase_control_point.write_value( + ASE_Release( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.RELEASING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.RELEASING] + ) + assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE]) + assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE]) + + await asyncio.sleep(0.001) + + # ----------------------------------------------------------------------------- async def run(): await test_pacs() From 81a6b1e097f238cdab733d1d36e868345ef0b9f7 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 8 Dec 2023 11:10:17 +0800 Subject: [PATCH 10/16] Replace 3.9 dict merger --- bumble/profiles/bap.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 63048f79..fcaf3f66 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -1075,11 +1075,14 @@ class AudioStreamControlService(gatt.TemplateService): ) -> None: self.device = device self.ase_state_machines = { - id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) - for id in sink_ase_id - } | { - id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) - for id in source_ase_id + **{ + id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) + for id in sink_ase_id + }, + **{ + id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) + for id in source_ase_id + }, } # ASE state machines, by ASE ID for ase in self.ase_state_machines.values(): From 085f163c92f768d0f72777500a01e83f5bd725c1 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 8 Dec 2023 10:14:38 -0800 Subject: [PATCH 11/16] add support for 2M phy --- .../google/bumble/btbench/L2capClient.kt | 52 +++++++++++++- .../google/bumble/btbench/L2capServer.kt | 5 +- .../google/bumble/btbench/MainActivity.kt | 68 +++++++++++++++---- .../com/github/google/bumble/btbench/Model.kt | 8 ++- .../google/bumble/btbench/RfcommServer.kt | 2 +- .../google/bumble/btbench/SocketClient.kt | 6 ++ .../google/bumble/btbench/SocketServer.kt | 5 +- .../google/bumble/remotehci/MainActivity.kt | 31 +++++---- 8 files changed, 139 insertions(+), 38 deletions(-) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt index 7722bb84..874bc26b 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt @@ -16,17 +16,63 @@ package com.github.google.bumble.btbench import android.annotation.SuppressLint import android.bluetooth.BluetoothAdapter -import java.io.IOException +import android.bluetooth.BluetoothDevice +import android.bluetooth.BluetoothGatt +import android.bluetooth.BluetoothGattCallback +import android.bluetooth.BluetoothProfile +import android.content.Context import java.util.logging.Logger -import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.l2cap-client") -class L2capClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) { +class L2capClient( + private val viewModel: AppViewModel, + val bluetoothAdapter: BluetoothAdapter, + val context: Context +) { @SuppressLint("MissingPermission") fun run() { viewModel.running = true val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress) + + val gatt = remoteDevice.connectGatt( + context, + false, + object : BluetoothGattCallback() { + override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) { + Log.info("MTU update: mtu=$mtu status=$status") + viewModel.mtu = mtu + } + + override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + override fun onConnectionStateChange( + gatt: BluetoothGatt?, status: Int, newState: Int + ) { + if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { + gatt.setPreferredPhy( + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_OPTION_NO_PREFERRED + ) + gatt.readPhy() + } + } + }, + BluetoothDevice.TRANSPORT_LE, + if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK + ) + val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm) val client = SocketClient(viewModel, socket) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt index 79c70045..76c297b3 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt @@ -30,7 +30,7 @@ private val Log = Logger.getLogger("btbench.l2cap-server") class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdapter: BluetoothAdapter) { @SuppressLint("MissingPermission") fun run() { - // Advertise to that the peer can find us and connect. + // Advertise so that the peer can find us and connect. val callback = object: AdvertiseCallback() { override fun onStartFailure(errorCode: Int) { Log.warning("failed to start advertising: $errorCode") @@ -50,13 +50,12 @@ class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdap val advertiseData = AdvertiseData.Builder().build() val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build() val advertiser = bluetoothAdapter.bluetoothLeAdvertiser - advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel() viewModel.l2capPsm = serverSocket.psm Log.info("psm = $serverSocket.psm") val server = SocketServer(viewModel, serverSocket) - server.run({ advertiser.stopAdvertising(callback) }) + server.run({ advertiser.stopAdvertising(callback) }, { advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) }) } } \ No newline at end of file diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt index 314f7465..60818371 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt @@ -26,23 +26,33 @@ import android.os.Bundle import androidx.activity.ComponentActivity import androidx.activity.compose.setContent import androidx.activity.result.contract.ActivityResultContracts +import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Button import androidx.compose.material3.Divider import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Slider import androidx.compose.material3.Surface +import androidx.compose.material3.Switch import androidx.compose.material3.Text import androidx.compose.material3.TextField import androidx.compose.runtime.Composable +import androidx.compose.runtime.remember +import androidx.compose.ui.Alignment import androidx.compose.ui.ExperimentalComposeUiApi import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.FocusRequester +import androidx.compose.ui.focus.focusRequester +import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalSoftwareKeyboardController import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.input.ImeAction @@ -171,7 +181,7 @@ class MainActivity : ComponentActivity() { } private fun runL2capClient() { - val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it) } + val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it, baseContext) } l2capClient?.run() } @@ -199,9 +209,12 @@ fun MainView( runL2capServer: () -> Unit ) { BTBenchTheme { - // A surface container using the 'background' color from the theme + val scrollState = rememberScrollState() Surface( - modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState), + color = MaterialTheme.colorScheme.background ) { Column(modifier = Modifier.padding(horizontal = 16.dp)) { Text( @@ -212,28 +225,33 @@ fun MainView( ) Divider() val keyboardController = LocalSoftwareKeyboardController.current - TextField(label = { - Text(text = "Peer Bluetooth Address") - }, + val focusRequester = remember { FocusRequester() } + val focusManager = LocalFocusManager.current + TextField( + label = { + Text(text = "Peer Bluetooth Address") + }, value = appViewModel.peerBluetoothAddress, - modifier = Modifier.fillMaxWidth(), + modifier = Modifier.fillMaxWidth().focusRequester(focusRequester), keyboardOptions = KeyboardOptions.Default.copy( keyboardType = KeyboardType.Ascii, imeAction = ImeAction.Done ), onValueChange = { appViewModel.updatePeerBluetoothAddress(it) }, - keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() }) + keyboardActions = KeyboardActions(onDone = { + keyboardController?.hide() + focusManager.clearFocus() + }) ) Divider() TextField(label = { Text(text = "L2CAP PSM") }, value = appViewModel.l2capPsm.toString(), - modifier = Modifier.fillMaxWidth(), + modifier = Modifier.fillMaxWidth().focusRequester(focusRequester), keyboardOptions = KeyboardOptions.Default.copy( - keyboardType = KeyboardType.Number, - imeAction = ImeAction.Done + keyboardType = KeyboardType.Number, imeAction = ImeAction.Done ), onValueChange = { if (it.isNotEmpty()) { @@ -243,7 +261,11 @@ fun MainView( } } }, - keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() })) + keyboardActions = KeyboardActions(onDone = { + keyboardController?.hide() + focusManager.clearFocus() + }) + ) Divider() Slider( value = appViewModel.senderPacketCountSlider, onValueChange = { @@ -264,7 +286,19 @@ fun MainView( ActionButton( text = "Become Discoverable", onClick = becomeDiscoverable, true ) - Row() { + Row( + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text(text = "2M PHY") + Spacer(modifier = Modifier.padding(start = 8.dp)) + Switch( + checked = appViewModel.use2mPhy, + onCheckedChange = { appViewModel.use2mPhy = it } + ) + + } + Row { ActionButton( text = "RFCOMM Client", onClick = runRfcommClient, !appViewModel.running ) @@ -272,7 +306,7 @@ fun MainView( text = "RFCOMM Server", onClick = runRfcommServer, !appViewModel.running ) } - Row() { + Row { ActionButton( text = "L2CAP Client", onClick = runL2capClient, !appViewModel.running ) @@ -281,6 +315,12 @@ fun MainView( ) } Divider() + Text( + text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else "" + ) + Text( + text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else "" + ) Text( text = "Packets Sent: ${appViewModel.packetsSent}" ) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt index 93755e40..b709be32 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt @@ -32,6 +32,10 @@ class AppViewModel : ViewModel() { private var preferences: SharedPreferences? = null var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS) var l2capPsm by mutableStateOf(0) + var use2mPhy by mutableStateOf(true) + var mtu by mutableStateOf(0) + var rxPhy by mutableStateOf(0) + var txPhy by mutableStateOf(0) var senderPacketCountSlider by mutableFloatStateOf(0.0F) var senderPacketSizeSlider by mutableFloatStateOf(0.0F) var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT) @@ -116,7 +120,7 @@ class AppViewModel : ViewModel() { } fun updateSenderPacketSizeSlider() { - if (senderPacketSize <= 1) { + if (senderPacketSize <= 16) { senderPacketSizeSlider = 0.0F } else if (senderPacketSize <= 256) { senderPacketSizeSlider = 0.02F @@ -138,7 +142,7 @@ class AppViewModel : ViewModel() { fun updateSenderPacketSize() { if (senderPacketSizeSlider < 0.1F) { - senderPacketSize = 1 + senderPacketSize = 16 } else if (senderPacketSizeSlider < 0.3F) { senderPacketSize = 256 } else if (senderPacketSizeSlider < 0.5F) { diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt index f06736b2..69612c55 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt @@ -30,6 +30,6 @@ class RfcommServer(private val viewModel: AppViewModel, val bluetoothAdapter: Bl ) val server = SocketServer(viewModel, serverSocket) - server.run({}) + server.run({}, {}) } } \ No newline at end of file diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt index cc5058e4..28c53542 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt @@ -22,6 +22,8 @@ import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.socket-client") +private const val DEFAULT_STARTUP_DELAY = 1000 + class SocketClient(private val viewModel: AppViewModel, private val socket: BluetoothSocket) { @SuppressLint("MissingPermission") fun run() { @@ -56,6 +58,10 @@ class SocketClient(private val viewModel: AppViewModel, private val socket: Blue socketDataSource.receive() } + Log.info("Startup delay: $DEFAULT_STARTUP_DELAY") + Thread.sleep(DEFAULT_STARTUP_DELAY.toLong()); + Log.info("Starting to send") + sender.run() cleanup() } diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt index 3f9c3e1d..e461617d 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt @@ -22,14 +22,13 @@ import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.socket-server") class SocketServer(private val viewModel: AppViewModel, private val serverSocket: BluetoothServerSocket) { - fun run(onTerminate: () -> Unit) { + fun run(onConnected: () -> Unit, onDisconnected: () -> Unit) { var aborted = false viewModel.running = true fun cleanup() { serverSocket.close() viewModel.running = false - onTerminate() } thread(name = "SocketServer") { @@ -38,6 +37,7 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket serverSocket.close() } Log.info("waiting for connection...") + onDisconnected() val socket = try { serverSocket.accept() } catch (error: IOException) { @@ -46,6 +46,7 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket return@thread } Log.info("got connection") + onConnected() viewModel.aborter = { aborted = true diff --git a/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt b/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt index ebdc708a..493b7e50 100644 --- a/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt +++ b/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt @@ -10,8 +10,10 @@ import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Button import androidx.compose.material3.Divider import androidx.compose.material3.ExperimentalMaterial3Api @@ -71,7 +73,7 @@ class AppViewModel : ViewModel(), HciProxy.Listener { this.tcpPort = tcpPort // Save the port to the preferences - with (preferences!!.edit()) { + with(preferences!!.edit()) { putString(TCP_PORT_PREF_KEY, tcpPort.toString()) apply() } @@ -138,7 +140,8 @@ class MainActivity : ComponentActivity() { log.warning("Exception while running HCI Server: $error") } catch (error: HalException) { log.warning("HAL exception: ${error.message}") - appViewModel.message = "Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell." + appViewModel.message = + "Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell." } log.info("HCI Proxy thread ended") appViewModel.canStart = true @@ -157,9 +160,12 @@ fun ActionButton(text: String, onClick: () -> Unit, enabled: Boolean) { @Composable fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { RemoteHCITheme { - // A surface container using the 'background' color from the theme + val scrollState = rememberScrollState() Surface( - modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState), + color = MaterialTheme.colorScheme.background ) { Column(modifier = Modifier.padding(horizontal = 16.dp)) { Text( @@ -174,13 +180,15 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { ) Divider() val keyboardController = LocalSoftwareKeyboardController.current - TextField( - label = { - Text(text = "TCP Port") - }, + TextField(label = { + Text(text = "TCP Port") + }, value = appViewModel.tcpPort.toString(), modifier = Modifier.fillMaxWidth(), - keyboardOptions = KeyboardOptions.Default.copy(keyboardType = KeyboardType.Number, imeAction = ImeAction.Done), + keyboardOptions = KeyboardOptions.Default.copy( + keyboardType = KeyboardType.Number, + imeAction = ImeAction.Done + ), onValueChange = { if (it.isNotEmpty()) { val tcpPort = it.toIntOrNull() @@ -189,10 +197,7 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { } } }, - keyboardActions = KeyboardActions( - onDone = {keyboardController?.hide()} - ) - ) + keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() })) Divider() val connectState = if (appViewModel.hostConnected) "CONNECTED" else "DISCONNECTED" Text( From 62a8ced44772cde4a175da3dd69d3de850b2751e Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 8 Dec 2023 17:28:57 -0800 Subject: [PATCH 12/16] support drivers that can't use reset directly. --- bumble/drivers/__init__.py | 21 +++++++++++++---- bumble/drivers/rtk.py | 11 +++++++-- bumble/host.py | 38 +++++++++++++++++------------- bumble/transport/__init__.py | 38 ++++++++++++++++++++++++++---- bumble/transport/common.py | 2 +- docs/mkdocs/src/drivers/index.md | 9 +++++++ docs/mkdocs/src/drivers/realtek.md | 7 ++++-- 7 files changed, 94 insertions(+), 32 deletions(-) diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index d8ea06e6..0a38f086 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -60,12 +60,23 @@ class Driver(abc.ABC): # Functions # ----------------------------------------------------------------------------- async def get_driver_for_host(host): - """Probe all known diver classes until one returns a valid instance for a host, - or none is found. + """Probe diver classes until one returns a valid instance for a host, or none is + found. + If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - if driver := await rtk.Driver.for_host(host): - logger.debug("Instantiated RTK driver") - return driver + driver_classes = {"rtk": rtk.Driver} + if driver_name := host.hci_metadata.get("driver"): + # Only probe a single driver + probe_list = [driver_name] + else: + # Probe all drivers + probe_list = driver_classes.keys() + + for driver_name in probe_list: + logger.debug(f"Probing {driver_name} driver class") + if driver := await rtk.Driver.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver return None diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index f78a14d3..0b64e0cd 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -470,8 +470,12 @@ class Driver: logger.debug("USB metadata not found") return False - vendor_id = host.hci_metadata.get("vendor_id", None) - product_id = host.hci_metadata.get("product_id", None) + if host.hci_metadata.get('driver') == 'rtk': + # Forced driver + return True + + vendor_id = host.hci_metadata.get("vendor_id") + product_id = host.hci_metadata.get("product_id") if vendor_id is None or product_id is None: logger.debug("USB metadata not sufficient") return False @@ -486,6 +490,9 @@ class Driver: @classmethod async def driver_info_for_host(cls, host): + await host.send_command(HCI_Reset_Command(), check_result=True) + host.ready = True # Needed to let the host know the controller is ready. + response = await host.send_command( HCI_Read_Local_Version_Information_Command(), check_result=True ) diff --git a/bumble/host.py b/bumble/host.py index 3ae2280b..190ab89e 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,7 @@ import collections import logging import struct -from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING from bumble.colors import color from bumble.l2cap import L2CAP_PDU @@ -124,7 +124,8 @@ class Connection: class Host(AbortableEventEmitter): connections: Dict[int, Connection] acl_packet_queue: collections.deque[HCI_AclDataPacket] - hci_sink: TransportSink + hci_sink: Optional[TransportSink] = None + hci_metadata: Dict[str, Any] long_term_key_provider: Optional[ Callable[[int, bytes, int], Awaitable[Optional[bytes]]] ] @@ -137,9 +138,8 @@ class Host(AbortableEventEmitter): ) -> None: super().__init__() - self.hci_metadata = None + self.hci_metadata = {} self.ready = False # True when we can accept incoming packets - self.reset_done = False self.connections = {} # Connections, by connection handle self.pending_command = None self.pending_response = None @@ -162,10 +162,7 @@ class Host(AbortableEventEmitter): # Connect to the source and sink if specified if controller_source: - controller_source.set_packet_sink(self) - self.hci_metadata = getattr( - controller_source, 'metadata', self.hci_metadata - ) + self.set_packet_source(controller_source) if controller_sink: self.set_packet_sink(controller_sink) @@ -200,17 +197,21 @@ class Host(AbortableEventEmitter): self.ready = False await self.flush() - await self.send_command(HCI_Reset_Command(), check_result=True) - self.ready = True - # Instantiate and init a driver for the host if needed. # NOTE: we don't keep a reference to the driver here, because we don't # currently have a need for the driver later on. But if the driver interface # evolves, it may be required, then, to store a reference to the driver in # an object property. + reset_needed = True if driver_factory is not None: if driver := await driver_factory(self): await driver.init_controller() + reset_needed = False + + # Send a reset command unless a driver has already done so. + if reset_needed: + await self.send_command(HCI_Reset_Command(), check_result=True) + self.ready = True response = await self.send_command( HCI_Read_Local_Supported_Commands_Command(), check_result=True @@ -313,25 +314,28 @@ class Host(AbortableEventEmitter): ) ) - self.reset_done = True - @property - def controller(self) -> TransportSink: + def controller(self) -> Optional[TransportSink]: return self.hci_sink @controller.setter - def controller(self, controller): + def controller(self, controller) -> None: self.set_packet_sink(controller) if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink: TransportSink) -> None: + def set_packet_sink(self, sink: Optional[TransportSink]) -> None: self.hci_sink = sink + def set_packet_source(self, source: TransportSource) -> None: + source.set_packet_sink(self) + self.hci_metadata = getattr(source, 'metadata', self.hci_metadata) + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) + if self.hci_sink: + self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index bc0766b2..4822dfe1 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from typing import Optional from .common import Transport, AsyncPipeSink, SnoopingTransport from ..snoop import create_snooper @@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport: async def open_transport(name: str) -> Transport: """ Open a transport by name. - The name must be : - Where depend on the type (and may be empty for some types). + The name must be : + Where depend on the type (and may be empty for some types), and + is either omitted, or a ,-separated list of = pairs, + enclosed in []. + If there are not metadata or parameter, the : after the may be omitted. + Examples: + * usb:0 + * usb:[driver=rtk]0 + * android-netsim + The supported types are: * serial * udp @@ -71,15 +80,34 @@ async def open_transport(name: str) -> Transport: * android-netsim """ - return _wrap_transport(await _open_transport(name)) + scheme, *tail = name.split(':', 1) + spec = tail[0] if tail else None + if spec: + # Metadata may precede the spec + if spec.startswith('['): + metadata_str, *tail = spec[1:].split(']') + spec = tail[0] if tail else None + metadata = dict([entry.split('=') for entry in metadata_str.split(',')]) + else: + metadata = None + + transport = await _open_transport(scheme, spec) + if metadata: + transport.source.metadata = { # type: ignore[attr-defined] + **metadata, + **getattr(transport.source, 'metadata', {}), + } + # pylint: disable=line-too-long + logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined] + + return _wrap_transport(transport) # ----------------------------------------------------------------------------- -async def _open_transport(name: str) -> Transport: +async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: # pylint: disable=import-outside-toplevel # pylint: disable=too-many-return-statements - scheme, *spec = name.split(':', 1) if scheme == 'serial' and spec: from .serial import open_serial_transport diff --git a/bumble/transport/common.py b/bumble/transport/common.py index ace04da5..f767f54f 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -21,7 +21,7 @@ import struct import asyncio import logging import io -from typing import ContextManager, Tuple, Optional, Protocol, Dict +from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from bumble import hci from bumble.colors import color diff --git a/docs/mkdocs/src/drivers/index.md b/docs/mkdocs/src/drivers/index.md index a904e006..cb0a981e 100644 --- a/docs/mkdocs/src/drivers/index.md +++ b/docs/mkdocs/src/drivers/index.md @@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly. This may include, for instance, loading a Firmware image or patch, loading a configuration. +By default, drivers will be automatically probed to determine if they should be +used with particular HCI controller. +When the transport for an HCI controller is instantiated from a transport name, +a driver may also be forced by specifying ``driver=`` in the optional +metadata portion of the transport name. For example, +``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the +first USB device, even if a normal probe would not have selected it based on the +USB vendor ID and product ID. + Drivers included in the module are: * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. \ No newline at end of file diff --git a/docs/mkdocs/src/drivers/realtek.md b/docs/mkdocs/src/drivers/realtek.md index acbce490..599ce048 100644 --- a/docs/mkdocs/src/drivers/realtek.md +++ b/docs/mkdocs/src/drivers/realtek.md @@ -1,13 +1,16 @@ REALTEK DRIVER ============== -This driver supports loading firmware images and optional config data to +This driver supports loading firmware images and optional config data to USB dongles with a Realtek chipset. A number of USB dongles are supported, but likely not all. -When using a USB dongle, the USB product ID and manufacturer ID are used +When using a USB dongle, the USB product ID and vendor ID are used to find whether a matching set of firmware image and config data is needed for that specific model. If a match exists, the driver will try load the firmware image and, if needed, config data. +Alternatively, the metadata property ``driver=rtk`` may be specified in a transport +name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just +``usb:0`` for the first USB device). The driver will look for those files by name, in order, in: * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR` From d35643524eb46de8efe4c7d53e8f34831d828700 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 8 Dec 2023 18:46:25 -0800 Subject: [PATCH 13/16] allow specifying the address type --- .../github/google/bumble/btbench/L2capClient.kt | 16 +++++++++++++++- .../com/github/google/bumble/btbench/Model.kt | 5 +++-- .../github/google/bumble/btbench/RfcommClient.kt | 3 ++- .../github/google/bumble/btbench/SocketClient.kt | 2 +- .../github/google/bumble/btbench/SocketServer.kt | 2 +- 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt index 874bc26b..228a741e 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capClient.kt @@ -21,6 +21,7 @@ import android.bluetooth.BluetoothGatt import android.bluetooth.BluetoothGattCallback import android.bluetooth.BluetoothProfile import android.content.Context +import android.os.Build import java.util.logging.Logger private val Log = Logger.getLogger("btbench.l2cap-client") @@ -33,7 +34,20 @@ class L2capClient( @SuppressLint("MissingPermission") fun run() { viewModel.running = true - val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress) + val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P") + val address = viewModel.peerBluetoothAddress.take(17) + val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + bluetoothAdapter.getRemoteLeDevice( + address, + if (addressIsPublic) { + BluetoothDevice.ADDRESS_TYPE_PUBLIC + } else { + BluetoothDevice.ADDRESS_TYPE_RANDOM + } + ) + } else { + bluetoothAdapter.getRemoteDevice(address) + } val gatt = remoteDevice.connectGatt( context, diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt index b709be32..35ee8da5 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt @@ -68,11 +68,12 @@ class AppViewModel : ViewModel() { } fun updatePeerBluetoothAddress(peerBluetoothAddress: String) { - this.peerBluetoothAddress = peerBluetoothAddress + val address = peerBluetoothAddress.uppercase() + this.peerBluetoothAddress = address // Save the address to the preferences with(preferences!!.edit()) { - putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, peerBluetoothAddress) + putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, address) apply() } } diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt index 644a5bda..e976c429 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt @@ -25,7 +25,8 @@ private val Log = Logger.getLogger("btbench.rfcomm-client") class RfcommClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) { @SuppressLint("MissingPermission") fun run() { - val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress) + val address = viewModel.peerBluetoothAddress.take(17) + val remoteDevice = bluetoothAdapter.getRemoteDevice(address) val socket = remoteDevice.createInsecureRfcommSocketToServiceRecord( DEFAULT_RFCOMM_UUID ) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt index 28c53542..bd5b7f4a 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt @@ -22,7 +22,7 @@ import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.socket-client") -private const val DEFAULT_STARTUP_DELAY = 1000 +private const val DEFAULT_STARTUP_DELAY = 3000 class SocketClient(private val viewModel: AppViewModel, private val socket: BluetoothSocket) { @SuppressLint("MissingPermission") diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt index e461617d..e83a47f2 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt @@ -45,7 +45,7 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket cleanup() return@thread } - Log.info("got connection") + Log.info("got connection from ${socket.remoteDevice.address}") onConnected() viewModel.aborter = { From b083cc99ade6aa129fed40e15f6c433a90feaa8b Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 8 Dec 2023 18:57:02 -0800 Subject: [PATCH 14/16] fix spec parsing --- bumble/transport/__init__.py | 32 ++++++++++++++-------------- bumble/transport/android_emulator.py | 2 +- bumble/transport/hci_socket.py | 5 +---- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index 4822dfe1..065e6964 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -111,75 +111,75 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: if scheme == 'serial' and spec: from .serial import open_serial_transport - return await open_serial_transport(spec[0]) + return await open_serial_transport(spec) if scheme == 'udp' and spec: from .udp import open_udp_transport - return await open_udp_transport(spec[0]) + return await open_udp_transport(spec) if scheme == 'tcp-client' and spec: from .tcp_client import open_tcp_client_transport - return await open_tcp_client_transport(spec[0]) + return await open_tcp_client_transport(spec) if scheme == 'tcp-server' and spec: from .tcp_server import open_tcp_server_transport - return await open_tcp_server_transport(spec[0]) + return await open_tcp_server_transport(spec) if scheme == 'ws-client' and spec: from .ws_client import open_ws_client_transport - return await open_ws_client_transport(spec[0]) + return await open_ws_client_transport(spec) if scheme == 'ws-server' and spec: from .ws_server import open_ws_server_transport - return await open_ws_server_transport(spec[0]) + return await open_ws_server_transport(spec) if scheme == 'pty': from .pty import open_pty_transport - return await open_pty_transport(spec[0] if spec else None) + return await open_pty_transport(spec) if scheme == 'file': from .file import open_file_transport assert spec is not None - return await open_file_transport(spec[0]) + return await open_file_transport(spec) if scheme == 'vhci': from .vhci import open_vhci_transport - return await open_vhci_transport(spec[0] if spec else None) + return await open_vhci_transport(spec) if scheme == 'hci-socket': from .hci_socket import open_hci_socket_transport - return await open_hci_socket_transport(spec[0] if spec else None) + return await open_hci_socket_transport(spec) if scheme == 'usb': from .usb import open_usb_transport - assert spec is not None - return await open_usb_transport(spec[0]) + assert spec + return await open_usb_transport(spec) if scheme == 'pyusb': from .pyusb import open_pyusb_transport - assert spec is not None - return await open_pyusb_transport(spec[0]) + assert spec + return await open_pyusb_transport(spec) if scheme == 'android-emulator': from .android_emulator import open_android_emulator_transport - return await open_android_emulator_transport(spec[0] if spec else None) + return await open_android_emulator_transport(spec) if scheme == 'android-netsim': from .android_netsim import open_android_netsim_transport - return await open_android_netsim_transport(spec[0] if spec else None) + return await open_android_netsim_transport(spec) raise ValueError('unknown transport scheme') diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 8d19a9e2..9cd7ec21 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: mode = 'host' server_host = 'localhost' server_port = '8554' - if spec is not None: + if spec: params = spec.split(',') for param in params: if param.startswith('mode='): diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index df9e885a..41250433 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport: ) from error # Compute the adapter index - if spec is None: - adapter_index = 0 - else: - adapter_index = int(spec) + adapter_index = int(spec) if spec else 0 # Bind the socket # NOTE: since Python doesn't support binding with the required address format (yet), From f911163e49512e79d91da391e406050f555e7457 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Tue, 12 Dec 2023 00:36:24 +0800 Subject: [PATCH 15/16] Improve ASCS logging --- bumble/profiles/bap.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index fcaf3f66..4785997b 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -807,7 +807,7 @@ class AseStateMachine(gatt.Characteristic): ) -> None: self.service = service self.ase_id = ase_id - self.state = AseStateMachine.State.IDLE + self._state = AseStateMachine.State.IDLE self.role = role uuid = ( @@ -1001,6 +1001,15 @@ class AseStateMachine(gatt.Characteristic): self.service.device.abort_on('flush', remove_cis_async()) return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') + self._state = new_state + @property def value(self): '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' @@ -1060,6 +1069,12 @@ class AseStateMachine(gatt.Characteristic): def on_read(self, _: device.Connection) -> bytes: return self.value + def __str__(self) -> str: + return ( + f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' + f'state={self._state.name})' + ) + class AudioStreamControlService(gatt.TemplateService): UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE @@ -1085,9 +1100,6 @@ class AudioStreamControlService(gatt.TemplateService): }, } # ASE state machines, by ASE ID - for ase in self.ase_state_machines.values(): - print(ase.ase_id) - self.ase_control_point = gatt.Characteristic( uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, properties=gatt.Characteristic.Properties.WRITE From 98ed772e8a606e3efeb06d638ea34c3ab4de91dd Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Mon, 11 Dec 2023 17:52:04 -0800 Subject: [PATCH 16/16] address PR comments and add some typing --- bumble/drivers/__init__.py | 49 +++++++++++++------------------------- bumble/drivers/common.py | 45 ++++++++++++++++++++++++++++++++++ bumble/drivers/rtk.py | 4 ++-- 3 files changed, 64 insertions(+), 34 deletions(-) create mode 100644 bumble/drivers/common.py diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index 0a38f086..b5712e66 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -19,12 +19,17 @@ like loading firmware after a cold start. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import abc +from __future__ import annotations import logging import pathlib import platform -from . import rtk +from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING +from . import rtk +from .common import Driver + +if TYPE_CHECKING: + from bumble.host import Host # ----------------------------------------------------------------------------- # Logging @@ -32,39 +37,16 @@ from . import rtk logger = logging.getLogger(__name__) -# ----------------------------------------------------------------------------- -# Classes -# ----------------------------------------------------------------------------- -class Driver(abc.ABC): - """Base class for drivers.""" - - @staticmethod - async def for_host(_host): - """Return a driver instance for a host. - - Args: - host: Host object for which a driver should be created. - - Returns: - A Driver instance if a driver should be instantiated for this host, or - None if no driver instance of this class is needed. - """ - return None - - @abc.abstractmethod - async def init_controller(self): - """Initialize the controller.""" - - # ----------------------------------------------------------------------------- # Functions # ----------------------------------------------------------------------------- -async def get_driver_for_host(host): +async def get_driver_for_host(host: Host) -> Optional[Driver]: """Probe diver classes until one returns a valid instance for a host, or none is found. If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - driver_classes = {"rtk": rtk.Driver} + driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver} + probe_list: Iterable[str] if driver_name := host.hci_metadata.get("driver"): # Only probe a single driver probe_list = [driver_name] @@ -73,10 +55,13 @@ async def get_driver_for_host(host): probe_list = driver_classes.keys() for driver_name in probe_list: - logger.debug(f"Probing {driver_name} driver class") - if driver := await rtk.Driver.for_host(host): - logger.debug(f"Instantiated {driver_name} driver") - return driver + if driver_class := driver_classes.get(driver_name): + logger.debug(f"Probing driver class: {driver_name}") + if driver := await driver_class.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver + else: + logger.debug(f"Skipping unknown driver class: {driver_name}") return None diff --git a/bumble/drivers/common.py b/bumble/drivers/common.py new file mode 100644 index 00000000..a4c0427c --- /dev/null +++ b/bumble/drivers/common.py @@ -0,0 +1,45 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Common types for drivers. +""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import abc + + +# ----------------------------------------------------------------------------- +# Classes +# ----------------------------------------------------------------------------- +class Driver(abc.ABC): + """Base class for drivers.""" + + @staticmethod + async def for_host(_host): + """Return a driver instance for a host. + + Args: + host: Host object for which a driver should be created. + + Returns: + A Driver instance if a driver should be instantiated for this host, or + None if no driver instance of this class is needed. + """ + return None + + @abc.abstractmethod + async def init_controller(self): + """Initialize the controller.""" diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index 0b64e0cd..4a9034db 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -41,7 +41,7 @@ from bumble.hci import ( HCI_Reset_Command, HCI_Read_Local_Version_Information_Command, ) - +from bumble.drivers import common # ----------------------------------------------------------------------------- # Logging @@ -285,7 +285,7 @@ class Firmware: ) -class Driver: +class Driver(common.Driver): @dataclass class DriverInfo: rom: int