diff --git a/bumble/controller.py b/bumble/controller.py index d035bcca..4ead098e 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -1263,3 +1263,15 @@ class Controller: See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command ''' return struct.pack(' None: + await self.device.stop_legacy_advertising() + + +# ----------------------------------------------------------------------------- +@dataclass +class ExtendedAdvertiser(CompositeEventEmitter): + device: Device + handle: int + advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties + own_address_type: OwnAddressType + auto_restart: bool + advertising_data: Optional[bytes] + scan_response_data: Optional[bytes] + + def __post_init__(self) -> None: + super().__init__() + + async def stop(self) -> None: + await self.device.stop_extended_advertising(self.handle) + + # ----------------------------------------------------------------------------- class LePhyOptions: # Coded PHY preference @@ -658,6 +690,9 @@ class Connection(CompositeEventEmitter): gatt_client: gatt_client.Client pairing_peer_io_capability: Optional[int] pairing_peer_authentication_requirements: Optional[int] + advertiser_after_disconnection: Union[ + LegacyAdvertiser, ExtendedAdvertiser, None + ] = None @composite_listener class Listener: @@ -1063,7 +1098,8 @@ class Device(CompositeEventEmitter): ] advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] config: DeviceConfiguration - extended_advertising_handles: Set[int] + legacy_advertiser: Optional[LegacyAdvertiser] + extended_advertisers: Dict[int, ExtendedAdvertiser] sco_links: Dict[int, ScoLink] cis_links: Dict[int, CisLink] _pending_cis: Dict[int, Tuple[int, int]] @@ -1141,10 +1177,7 @@ class Device(CompositeEventEmitter): self._host = None self.powered_on = False - self.advertising = False - self.advertising_type = None self.auto_restart_inquiry = True - self.auto_restart_advertising = False self.command_timeout = 10 # seconds self.gatt_server = gatt_server.Server(self) self.sdp_server = sdp.Server(self) @@ -1168,10 +1201,10 @@ class Device(CompositeEventEmitter): self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY - self.extended_advertising_handles = set() + self.legacy_advertiser = None + self.extended_advertisers = {} # Own address type cache - self.advertising_own_address_type = None self.connect_own_address_type = None # Use the initial config or a default @@ -1579,6 +1612,7 @@ class Device(CompositeEventEmitter): return self.host.supports_le_feature(feature_map[phy]) + @deprecated("Please use start_legacy_advertising.") async def start_advertising( self, advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, @@ -1586,15 +1620,49 @@ class Device(CompositeEventEmitter): own_address_type: int = OwnAddressType.RANDOM, auto_restart: bool = False, ) -> None: + await self.start_legacy_advertising( + advertising_type=advertising_type, + target=target, + own_address_type=OwnAddressType(own_address_type), + auto_restart=auto_restart, + ) + + async def start_legacy_advertising( + self, + advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, + target: Optional[Address] = None, + own_address_type: OwnAddressType = OwnAddressType.RANDOM, + auto_restart: bool = False, + advertising_data: Optional[bytes] = None, + scan_response_data: Optional[bytes] = None, + ) -> LegacyAdvertiser: + """Starts an legacy advertisement. + + Args: + advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command. + target: Directed advertising target. Directed type should be set in advertising_type arg. + own_address_type: own address type to use in the advertising. + auto_restart: whether the advertisement will be restarted after disconnection. + scan_response_data: raw scan response. + advertising_data: raw advertising data. + + Returns: + LegacyAdvertiser object containing the metadata of advertisement. + """ + if self.extended_advertisers: + logger.warning( + 'Trying to start Legacy and Extended Advertising at the same time!' + ) + # If we're advertising, stop first - if self.advertising: + if self.legacy_advertiser: await self.stop_advertising() # Set/update the advertising data if the advertising type allows it if advertising_type.has_data: await self.send_command( HCI_LE_Set_Advertising_Data_Command( - advertising_data=self.advertising_data + advertising_data=advertising_data or self.advertising_data or b'' ), check_result=True, ) @@ -1603,7 +1671,9 @@ class Device(CompositeEventEmitter): if advertising_type.is_scannable: await self.send_command( HCI_LE_Set_Scan_Response_Data_Command( - scan_response_data=self.scan_response_data + scan_response_data=scan_response_data + or self.scan_response_data + or b'' ), check_result=True, ) @@ -1640,45 +1710,57 @@ class Device(CompositeEventEmitter): check_result=True, ) - self.advertising_type = advertising_type - self.advertising_own_address_type = own_address_type - self.advertising = True - self.auto_restart_advertising = auto_restart + self.legacy_advertiser = LegacyAdvertiser( + device=self, + advertising_type=advertising_type, + own_address_type=own_address_type, + auto_restart=auto_restart, + advertising_data=advertising_data, + scan_response_data=scan_response_data, + ) + return self.legacy_advertiser + @deprecated("Please use stop_legacy_advertising.") async def stop_advertising(self) -> None: + await self.stop_legacy_advertising() + + async def stop_legacy_advertising(self) -> None: # Disable advertising - if self.advertising: + if self.legacy_advertiser: await self.send_command( HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), check_result=True, ) - self.advertising_type = None - self.advertising_own_address_type = None - self.advertising = False - self.auto_restart_advertising = False + self.legacy_advertiser = None @experimental('Extended Advertising is still experimental - Might be changed soon.') async def start_extended_advertising( self, advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING, target: Address = Address.ANY, - own_address_type: int = OwnAddressType.RANDOM, - scan_response: Optional[bytes] = None, + own_address_type: OwnAddressType = OwnAddressType.RANDOM, + auto_restart: bool = True, advertising_data: Optional[bytes] = None, - ) -> int: + scan_response_data: Optional[bytes] = None, + ) -> ExtendedAdvertiser: """Starts an extended advertising set. Args: advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command target: Directed advertising target. Directed property should be set in advertising_properties arg. own_address_type: own address type to use in the advertising. - scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent. + auto_restart: whether the advertisement will be restarted after disconnection. advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent. + scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent. Returns: - Handle of the new advertising set. + ExtendedAdvertiser object containing the metadata of advertisement. """ + if self.legacy_advertiser: + logger.warning( + 'Trying to start Legacy and Extended Advertising at the same time!' + ) adv_handle = -1 # Find a free handle @@ -1686,7 +1768,7 @@ class Device(CompositeEventEmitter): DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE, DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1, ): - if i not in self.extended_advertising_handles: + if i not in self.extended_advertisers: adv_handle = i break @@ -1733,13 +1815,13 @@ class Device(CompositeEventEmitter): ) # Set the scan response if present - if scan_response is not None: + if scan_response_data is not None: await self.send_command( HCI_LE_Set_Extended_Scan_Response_Data_Command( advertising_handle=adv_handle, operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, fragment_preference=0x01, # Should not fragment - scan_response_data=scan_response, + scan_response_data=scan_response_data, ), check_result=True, ) @@ -1774,8 +1856,16 @@ class Device(CompositeEventEmitter): ) raise error - self.extended_advertising_handles.add(adv_handle) - return adv_handle + advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser( + device=self, + handle=adv_handle, + advertising_properties=advertising_properties, + own_address_type=own_address_type, + auto_restart=auto_restart, + advertising_data=advertising_data, + scan_response_data=scan_response_data, + ) + return advertiser @experimental('Extended Advertising is still experimental - Might be changed soon.') async def stop_extended_advertising(self, adv_handle: int) -> None: @@ -1799,11 +1889,11 @@ class Device(CompositeEventEmitter): HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), check_result=True, ) - self.extended_advertising_handles.remove(adv_handle) + del self.extended_advertisers[adv_handle] @property def is_advertising(self): - return self.advertising + return self.legacy_advertiser or self.extended_advertisers async def start_scanning( self, @@ -3144,13 +3234,18 @@ class Device(CompositeEventEmitter): # Guess which own address type is used for this connection. # This logic is somewhat correct but may need to be improved # when multiple advertising are run simultaneously. + advertiser = None if self.connect_own_address_type is not None: own_address_type = self.connect_own_address_type + elif self.legacy_advertiser: + own_address_type = self.legacy_advertiser.own_address_type + # Store advertiser for restarting - it's only required for legacy, since + # extended advertisement produces HCI_Advertising_Set_Terminated. + if self.legacy_advertiser.auto_restart: + advertiser = self.legacy_advertiser else: - own_address_type = self.advertising_own_address_type - - # We are no longer advertising - self.advertising = False + # For extended advertisement, determining own address type later. + own_address_type = OwnAddressType.RANDOM if own_address_type in ( OwnAddressType.PUBLIC, @@ -3172,6 +3267,7 @@ class Device(CompositeEventEmitter): connection_parameters, ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY), ) + connection.advertiser_after_disconnection = advertiser self.connections[connection_handle] = connection # If supported, read which PHY we're connected with before @@ -3203,10 +3299,10 @@ class Device(CompositeEventEmitter): # For directed advertising, this means a timeout if ( transport == BT_LE_TRANSPORT - and self.advertising - and self.advertising_type.is_directed + and self.legacy_advertiser + and self.legacy_advertiser.advertising_type.is_directed ): - self.advertising = False + self.legacy_advertiser = None # Notify listeners error = core.ConnectionError( @@ -3268,16 +3364,30 @@ class Device(CompositeEventEmitter): self.gatt_server.on_disconnection(connection) # Restart advertising if auto-restart is enabled - if self.auto_restart_advertising: + if advertiser := connection.advertiser_after_disconnection: logger.debug('restarting advertising') - self.abort_on( - 'flush', - self.start_advertising( - advertising_type=self.advertising_type, # type: ignore[arg-type] - own_address_type=self.advertising_own_address_type, # type: ignore[arg-type] - auto_restart=True, - ), - ) + if isinstance(advertiser, LegacyAdvertiser): + self.abort_on( + 'flush', + self.start_legacy_advertising( + advertising_type=advertiser.advertising_type, + own_address_type=advertiser.own_address_type, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + auto_restart=True, + ), + ) + elif isinstance(advertiser, ExtendedAdvertiser): + self.abort_on( + 'flush', + self.start_extended_advertising( + advertising_properties=advertiser.advertising_properties, + own_address_type=advertiser.own_address_type, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + auto_restart=True, + ), + ) elif sco_link := self.sco_links.pop(connection_handle, None): sco_link.emit('disconnection', reason) elif cis_link := self.cis_links.pop(connection_handle, None): @@ -3600,6 +3710,30 @@ class Device(CompositeEventEmitter): if sco_link := self.sco_links.get(sco_handle, None): sco_link.emit('pdu', packet) + # [LE only] + @host_event_handler + @experimental('Only for testing') + def on_advertising_set_termination( + self, + status: int, + advertising_handle: int, + connection_handle: int, + ) -> None: + if status == HCI_SUCCESS: + connection = self.lookup_connection(connection_handle) + if advertiser := self.extended_advertisers.pop(advertising_handle, None): + if connection: + if advertiser.auto_restart: + connection.advertiser_after_disconnection = advertiser + if advertiser.own_address_type in ( + OwnAddressType.PUBLIC, + OwnAddressType.RESOLVABLE_OR_PUBLIC, + ): + connection.self_address = self.public_address + else: + connection.self_address = self.random_address + advertiser.emit('termination', status) + # [LE only] @host_event_handler @with_connection_from_handle diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index d8ea06e6..b5712e66 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -19,12 +19,17 @@ like loading firmware after a cold start. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import abc +from __future__ import annotations import logging import pathlib import platform -from . import rtk +from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING +from . import rtk +from .common import Driver + +if TYPE_CHECKING: + from bumble.host import Host # ----------------------------------------------------------------------------- # Logging @@ -32,40 +37,31 @@ from . import rtk logger = logging.getLogger(__name__) -# ----------------------------------------------------------------------------- -# Classes -# ----------------------------------------------------------------------------- -class Driver(abc.ABC): - """Base class for drivers.""" - - @staticmethod - async def for_host(_host): - """Return a driver instance for a host. - - Args: - host: Host object for which a driver should be created. - - Returns: - A Driver instance if a driver should be instantiated for this host, or - None if no driver instance of this class is needed. - """ - return None - - @abc.abstractmethod - async def init_controller(self): - """Initialize the controller.""" - - # ----------------------------------------------------------------------------- # Functions # ----------------------------------------------------------------------------- -async def get_driver_for_host(host): - """Probe all known diver classes until one returns a valid instance for a host, - or none is found. +async def get_driver_for_host(host: Host) -> Optional[Driver]: + """Probe diver classes until one returns a valid instance for a host, or none is + found. + If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - if driver := await rtk.Driver.for_host(host): - logger.debug("Instantiated RTK driver") - return driver + driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver} + probe_list: Iterable[str] + if driver_name := host.hci_metadata.get("driver"): + # Only probe a single driver + probe_list = [driver_name] + else: + # Probe all drivers + probe_list = driver_classes.keys() + + for driver_name in probe_list: + if driver_class := driver_classes.get(driver_name): + logger.debug(f"Probing driver class: {driver_name}") + if driver := await driver_class.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver + else: + logger.debug(f"Skipping unknown driver class: {driver_name}") return None diff --git a/bumble/drivers/common.py b/bumble/drivers/common.py new file mode 100644 index 00000000..a4c0427c --- /dev/null +++ b/bumble/drivers/common.py @@ -0,0 +1,45 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Common types for drivers. +""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import abc + + +# ----------------------------------------------------------------------------- +# Classes +# ----------------------------------------------------------------------------- +class Driver(abc.ABC): + """Base class for drivers.""" + + @staticmethod + async def for_host(_host): + """Return a driver instance for a host. + + Args: + host: Host object for which a driver should be created. + + Returns: + A Driver instance if a driver should be instantiated for this host, or + None if no driver instance of this class is needed. + """ + return None + + @abc.abstractmethod + async def init_controller(self): + """Initialize the controller.""" diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index f78a14d3..4a9034db 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -41,7 +41,7 @@ from bumble.hci import ( HCI_Reset_Command, HCI_Read_Local_Version_Information_Command, ) - +from bumble.drivers import common # ----------------------------------------------------------------------------- # Logging @@ -285,7 +285,7 @@ class Firmware: ) -class Driver: +class Driver(common.Driver): @dataclass class DriverInfo: rom: int @@ -470,8 +470,12 @@ class Driver: logger.debug("USB metadata not found") return False - vendor_id = host.hci_metadata.get("vendor_id", None) - product_id = host.hci_metadata.get("product_id", None) + if host.hci_metadata.get('driver') == 'rtk': + # Forced driver + return True + + vendor_id = host.hci_metadata.get("vendor_id") + product_id = host.hci_metadata.get("product_id") if vendor_id is None or product_id is None: logger.debug("USB metadata not sufficient") return False @@ -486,6 +490,9 @@ class Driver: @classmethod async def driver_info_for_host(cls, host): + await host.send_command(HCI_Reset_Command(), check_result=True) + host.ready = True # Needed to let the host know the controller is ready. + response = await host.send_command( HCI_Read_Local_Version_Information_Command(), check_result=True ) diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index cdf1b5e5..eca11ce4 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -961,7 +961,7 @@ class Server(EventEmitter): try: attribute.write_value(connection, request.attribute_value) except Exception as error: - logger.warning(f'!!! ignoring exception: {error}') + logger.exception(f'!!! ignoring exception: {error}') def on_att_handle_value_confirmation(self, connection, _confirmation): ''' diff --git a/bumble/hci.py b/bumble/hci.py index a28246ab..36c049c9 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -728,6 +728,19 @@ HCI_LE_PHY_TYPE_TO_BIT = { HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT } + +class Phy(enum.IntEnum): + LE_1M = 0x01 + LE_2M = 0x02 + LE_CODED = 0x03 + + +class PhyBit(enum.IntFlag): + LE_1M = 0b00000001 + LE_2M = 0b00000010 + LE_CODED = 0b00000100 + + # Connection Parameters HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25 HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25 @@ -1963,25 +1976,15 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_ # ----------------------------------------------------------------------------- -class OwnAddressType: +class OwnAddressType(enum.IntEnum): PUBLIC = 0 RANDOM = 1 RESOLVABLE_OR_PUBLIC = 2 RESOLVABLE_OR_RANDOM = 3 - TYPE_NAMES = { - PUBLIC: 'PUBLIC', - RANDOM: 'RANDOM', - RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC', - RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM', - } - - @staticmethod - def type_name(type_id): - return name_or_number(OwnAddressType.TYPE_NAMES, type_id) - - # pylint: disable-next=unnecessary-lambda - TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)} + @classmethod + def type_spec(cls): + return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name} # ----------------------------------------------------------------------------- @@ -3374,7 +3377,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command): ), }, ), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), ('advertising_channel_map', 1), @@ -3467,7 +3470,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command): ('le_scan_type', 1), ('le_scan_interval', 2), ('le_scan_window', 2), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('scanning_filter_policy', 1), ] ) @@ -3506,7 +3509,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command): ('initiator_filter_policy', 1), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('connection_interval_min', 2), ('connection_interval_max', 2), ('max_latency', 2), @@ -3913,7 +3916,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command): ), }, ), - ('own_address_type', OwnAddressType.TYPE_SPEC), + ('own_address_type', OwnAddressType.type_spec()), ('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address', Address.parse_address_preceded_by_type), ('advertising_filter_policy', 1), @@ -4309,7 +4312,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command): ('initiator_filter_policy:', self.initiator_filter_policy), ( 'own_address_type: ', - OwnAddressType.type_name(self.own_address_type), + OwnAddressType(self.own_address_type).name, ), ( 'peer_address_type: ', @@ -4551,6 +4554,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command): See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command ''' + class Direction(enum.IntEnum): + HOST_TO_CONTROLLER = 0x00 + CONTROLLER_TO_HOST = 0x01 + connection_handle: int data_path_direction: int data_path_id: int @@ -5190,6 +5197,21 @@ HCI_LE_Meta_Event.subevent_classes[ ] = HCI_LE_Extended_Advertising_Report_Event +# ----------------------------------------------------------------------------- +@HCI_LE_Meta_Event.event( + [ + ('status', 1), + ('advertising_handle', 1), + ('connection_handle', 2), + ('number_completed_extended_advertising_events', 1), + ] +) +class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event): + ''' + See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event + ''' + + # ----------------------------------------------------------------------------- @HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)]) class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event): diff --git a/bumble/host.py b/bumble/host.py index b06ceba4..190ab89e 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,7 @@ import collections import logging import struct -from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING from bumble.colors import color from bumble.l2cap import L2CAP_PDU @@ -124,7 +124,8 @@ class Connection: class Host(AbortableEventEmitter): connections: Dict[int, Connection] acl_packet_queue: collections.deque[HCI_AclDataPacket] - hci_sink: TransportSink + hci_sink: Optional[TransportSink] = None + hci_metadata: Dict[str, Any] long_term_key_provider: Optional[ Callable[[int, bytes, int], Awaitable[Optional[bytes]]] ] @@ -137,9 +138,8 @@ class Host(AbortableEventEmitter): ) -> None: super().__init__() - self.hci_metadata = None + self.hci_metadata = {} self.ready = False # True when we can accept incoming packets - self.reset_done = False self.connections = {} # Connections, by connection handle self.pending_command = None self.pending_response = None @@ -162,10 +162,7 @@ class Host(AbortableEventEmitter): # Connect to the source and sink if specified if controller_source: - controller_source.set_packet_sink(self) - self.hci_metadata = getattr( - controller_source, 'metadata', self.hci_metadata - ) + self.set_packet_source(controller_source) if controller_sink: self.set_packet_sink(controller_sink) @@ -200,17 +197,21 @@ class Host(AbortableEventEmitter): self.ready = False await self.flush() - await self.send_command(HCI_Reset_Command(), check_result=True) - self.ready = True - # Instantiate and init a driver for the host if needed. # NOTE: we don't keep a reference to the driver here, because we don't # currently have a need for the driver later on. But if the driver interface # evolves, it may be required, then, to store a reference to the driver in # an object property. + reset_needed = True if driver_factory is not None: if driver := await driver_factory(self): await driver.init_controller() + reset_needed = False + + # Send a reset command unless a driver has already done so. + if reset_needed: + await self.send_command(HCI_Reset_Command(), check_result=True) + self.ready = True response = await self.send_command( HCI_Read_Local_Supported_Commands_Command(), check_result=True @@ -313,25 +314,28 @@ class Host(AbortableEventEmitter): ) ) - self.reset_done = True - @property - def controller(self) -> TransportSink: + def controller(self) -> Optional[TransportSink]: return self.hci_sink @controller.setter - def controller(self, controller): + def controller(self, controller) -> None: self.set_packet_sink(controller) if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink: TransportSink) -> None: + def set_packet_sink(self, sink: Optional[TransportSink]) -> None: self.hci_sink = sink + def set_packet_source(self, source: TransportSource) -> None: + source.set_packet_sink(self) + self.hci_metadata = getattr(source, 'metadata', self.hci_metadata) + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) + if self.hci_sink: + self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') @@ -721,6 +725,14 @@ class Host(AbortableEventEmitter): def on_hci_le_extended_advertising_report_event(self, event): self.on_hci_le_advertising_report_event(event) + def on_hci_le_advertising_set_terminated_event(self, event): + self.emit( + 'advertising_set_termination', + event.status, + event.advertising_handle, + event.connection_handle, + ) + def on_hci_le_cis_request_event(self, event): self.emit( 'cis_request', diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 76015d52..4785997b 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -23,13 +23,21 @@ import dataclasses import enum import struct import functools -from typing import Optional, List, Union +import logging +from typing import Optional, List, Union, Type, Dict, Any, Tuple, cast +from bumble import colors +from bumble import device from bumble import hci from bumble import gatt from bumble import gatt_client +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -220,6 +228,231 @@ class SupportedFrameDuration(enum.IntFlag): DURATION_10000_US_PREFERRED = 0b0010 +# ----------------------------------------------------------------------------- +# ASE Operations +# ----------------------------------------------------------------------------- + + +class ASE_Operation: + ''' + See Audio Stream Control Service - 5 ASE Control operations. + ''' + + classes: Dict[int, Type[ASE_Operation]] = {} + op_code: int + name: str + fields: Optional[Sequence[Any]] = None + ase_id: List[int] + + class Opcode(enum.IntEnum): + # fmt: off + CONFIG_CODEC = 0x01 + CONFIG_QOS = 0x02 + ENABLE = 0x03 + RECEIVER_START_READY = 0x04 + DISABLE = 0x05 + RECEIVER_STOP_READY = 0x06 + UPDATE_METADATA = 0x07 + RELEASE = 0x08 + + @staticmethod + def from_bytes(pdu: bytes) -> ASE_Operation: + op_code = pdu[0] + + cls = ASE_Operation.classes.get(op_code) + if cls is None: + instance = ASE_Operation(pdu) + instance.name = ASE_Operation.Opcode(op_code).name + instance.op_code = op_code + return instance + self = cls.__new__(cls) + ASE_Operation.__init__(self, pdu) + if self.fields is not None: + self.init_from_bytes(pdu, 1) + return self + + @staticmethod + def subclass(fields): + def inner(cls: Type[ASE_Operation]): + try: + operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] + cls.name = operation.name + cls.op_code = operation + except: + raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') + cls.fields = fields + + # Register a factory for this class + ASE_Operation.classes[cls.op_code] = cls + + return cls + + return inner + + def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: + if self.fields is not None and kwargs: + hci.HCI_Object.init_from_fields(self, self.fields, kwargs) + if pdu is None: + pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( + kwargs, self.fields + ) + self.pdu = pdu + + def init_from_bytes(self, pdu: bytes, offset: int): + return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) + + def __bytes__(self) -> bytes: + return self.pdu + + def __str__(self) -> str: + result = f'{colors.color(self.name, "yellow")} ' + if fields := getattr(self, 'fields', None): + result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') + else: + if len(self.pdu) > 1: + result += f': {self.pdu.hex()}' + return result + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('target_latency', 1), + ('target_phy', 1), + ('codec_id', hci.CodingFormat.parse_from_bytes), + ('codec_specific_configuration', 'v'), + ], + ] +) +class ASE_Config_Codec(ASE_Operation): + ''' + See Audio Stream Control Service 5.1 - Config Codec Operation + ''' + + target_latency: List[int] + target_phy: List[int] + codec_id: List[hci.CodingFormat] + codec_specific_configuration: List[bytes] + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('cig_id', 1), + ('cis_id', 1), + ('sdu_interval', 3), + ('framing', 1), + ('phy', 1), + ('max_sdu', 2), + ('retransmission_number', 1), + ('max_transport_latency', 2), + ('presentation_delay', 3), + ], + ] +) +class ASE_Config_QOS(ASE_Operation): + ''' + See Audio Stream Control Service 5.2 - Config Qos Operation + ''' + + cig_id: List[int] + cis_id: List[int] + sdu_interval: List[int] + framing: List[int] + phy: List[int] + max_sdu: List[int] + retransmission_number: List[int] + max_transport_latency: List[int] + presentation_delay: List[int] + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Enable(ASE_Operation): + ''' + See Audio Stream Control Service 5.3 - Enable Operation + ''' + + metadata: bytes + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Start_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.4 - Receiver Start Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Disable(ASE_Operation): + ''' + See Audio Stream Control Service 5.5 - Disable Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Stop_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Update_Metadata(ASE_Operation): + ''' + See Audio Stream Control Service 5.7 - Update Metadata Operation + ''' + + metadata: List[bytes] + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Release(ASE_Operation): + ''' + See Audio Stream Control Service 5.8 - Release Operation + ''' + + +class AseResponseCode(enum.IntEnum): + # fmt: off + SUCCESS = 0x00 + UNSUPPORTED_OPCODE = 0x01 + INVALID_LENGTH = 0x02 + INVALID_ASE_ID = 0x03 + INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 + INVALID_ASE_DIRECTION = 0x05 + UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 + UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 + REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 + INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 + UNSUPPORTED_METADATA = 0x0A + REJECTED_METADATA = 0x0B + INVALID_METADATA = 0x0C + INSUFFICIENT_RESOURCES = 0x0D + UNSPECIFIED_ERROR = 0x0E + + +class AseReasonCode(enum.IntEnum): + # fmt: off + NONE = 0x00 + CODEC_ID = 0x01 + CODEC_SPECIFIC_CONFIGURATION = 0x02 + SDU_INTERVAL = 0x03 + FRAMING = 0x04 + PHY = 0x05 + MAXIMUM_SDU_SIZE = 0x06 + RETRANSMISSION_NUMBER = 0x07 + MAX_TRANSPORT_LATENCY = 0x08 + PRESENTATION_DELAY = 0x09 + INVALID_ASE_CIS_MAPPING = 0x0A + + +class AudioRole(enum.IntEnum): + SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST + SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER + + # ----------------------------------------------------------------------------- # Utils # ----------------------------------------------------------------------------- @@ -325,6 +558,80 @@ class CodecSpecificCapabilities: ) +@dataclasses.dataclass +class CodecSpecificConfiguration: + '''See: + * Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures + * Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements + ''' + + class Type(enum.IntEnum): + # fmt: off + SAMPLING_FREQUENCY = 0x01 + FRAME_DURATION = 0x02 + AUDIO_CHANNEL_ALLOCATION = 0x03 + OCTETS_PER_FRAME = 0x04 + CODEC_FRAMES_PER_SDU = 0x05 + + sampling_frequency: SamplingFrequency + frame_duration: FrameDuration + audio_channel_allocation: AudioLocation + octets_per_codec_frame: int + codec_frames_per_sdu: int + + @classmethod + def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration: + offset = 0 + # Allowed default values. + audio_channel_allocation = AudioLocation.NOT_ALLOWED + codec_frames_per_sdu = 1 + while offset < len(data): + length, type = struct.unpack_from('BB', data, offset) + offset += 2 + value = int.from_bytes(data[offset : offset + length - 1], 'little') + offset += length - 1 + + if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY: + sampling_frequency = SamplingFrequency(value) + elif type == CodecSpecificConfiguration.Type.FRAME_DURATION: + frame_duration = FrameDuration(value) + elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION: + audio_channel_allocation = AudioLocation(value) + elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME: + octets_per_codec_frame = value + elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU: + codec_frames_per_sdu = value + + # It is expected here that if some fields are missing, an error should be raised. + return CodecSpecificConfiguration( + sampling_frequency=sampling_frequency, + frame_duration=frame_duration, + audio_channel_allocation=audio_channel_allocation, + octets_per_codec_frame=octets_per_codec_frame, + codec_frames_per_sdu=codec_frames_per_sdu, + ) + + def __bytes__(self) -> bytes: + return struct.pack( + ' None: + self.service = service + self.ase_id = ase_id + self._state = AseStateMachine.State.IDLE + self.role = role + + uuid = ( + gatt.GATT_SINK_ASE_CHARACTERISTIC + if role == AudioRole.SINK + else gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + super().__init__( + uuid=uuid, + properties=gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.READABLE, + value=gatt.CharacteristicValue(read=self.on_read), + ) + + self.service.device.on('cis_request', self.on_cis_request) + self.service.device.on('cis_establishment', self.on_cis_establishment) + + def on_cis_request( + self, + acl_connection: device.Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ) -> None: + if cis_id == self.cis_id and self.state == self.State.ENABLING: + acl_connection.abort_on( + 'flush', self.service.device.accept_cis_request(cis_handle) + ) + + def on_cis_establishment(self, cis_link: device.CisLink) -> None: + if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: + self.state = self.State.STREAMING + self.cis_link = cis_link + + async def post_cis_established(): + await self.service.device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=cis_link.handle, + data_path_direction=self.role, + data_path_id=0x00, # Fixed HCI + codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ) + ) + await self.service.device.notify_subscribers(self, self.value) + + cis_link.acl_connection.abort_on('flush', post_cis_established()) + + def on_config_codec( + self, + target_latency: int, + target_phy: int, + codec_id: hci.CodingFormat, + codec_specific_configuration: bytes, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + self.State.IDLE, + self.State.CODEC_CONFIGURED, + self.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.max_transport_latency = target_latency + self.phy = target_phy + self.codec_id = codec_id + if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC: + self.codec_specific_configuration = codec_specific_configuration + else: + self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes( + codec_specific_configuration + ) + + self.state = self.State.CODEC_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_config_qos( + self, + cig_id: int, + cis_id: int, + sdu_interval: int, + framing: int, + phy: int, + max_sdu: int, + retransmission_number: int, + max_transport_latency: int, + presentation_delay: int, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.CODEC_CONFIGURED, + AseStateMachine.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.cig_id = cig_id + self.cis_id = cis_id + self.sdu_interval = sdu_interval + self.framing = framing + self.phy = phy + self.max_sdu = max_sdu + self.retransmission_number = retransmission_number + self.max_transport_latency = max_transport_latency + self.presentation_delay = presentation_delay + + self.state = self.State.QOS_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.QOS_CONFIGURED: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.metadata = metadata + self.state = self.State.ENABLING + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.ENABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.STREAMING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.DISABLING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.DISABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.QOS_CONFIGURED + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_update_metadata( + self, metadata: bytes + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.metadata = metadata + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state == AseStateMachine.State.IDLE: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.RELEASING + + async def remove_cis_async(): + await self.service.device.send_command( + hci.HCI_LE_Remove_ISO_Data_Path_Command( + connection_handle=self.cis_link.handle, + data_path_direction=self.role, + ) + ) + self.state = self.State.IDLE + await self.service.device.notify_subscribers(self, self.value) + + self.service.device.abort_on('flush', remove_cis_async()) + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') + self._state = new_state + + @property + def value(self): + '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' + + if self.state == self.State.CODEC_CONFIGURED: + codec_specific_configuration_bytes = bytes( + self.codec_specific_configuration + ) + additional_parameters = ( + struct.pack( + ' bytes: + return self.value + + def __str__(self) -> str: + return ( + f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' + f'state={self._state.name})' + ) + + +class AudioStreamControlService(gatt.TemplateService): + UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE + + ase_state_machines: Dict[int, AseStateMachine] + ase_control_point: gatt.Characteristic + + def __init__( + self, + device: device.Device, + source_ase_id: Sequence[int] = [], + sink_ase_id: Sequence[int] = [], + ) -> None: + self.device = device + self.ase_state_machines = { + **{ + id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) + for id in sink_ase_id + }, + **{ + id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) + for id in source_ase_id + }, + } # ASE state machines, by ASE ID + + self.ase_control_point = gatt.Characteristic( + uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.WRITE + | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.WRITEABLE, + value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), + ) + + super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) + + def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): + if ase := self.ase_state_machines.get(ase_id): + handler = getattr(ase, 'on_' + opcode.name.lower()) + return (ase_id, *handler(*args)) + else: + return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + + def on_write_ase_control_point(self, connection, data): + operation = ASE_Operation.from_bytes(data) + responses = [] + logger.debug(f'*** ASCS Write {operation} ***') + + if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: + for ase_id, *args in zip( + operation.ase_id, + operation.target_latency, + operation.target_phy, + operation.codec_id, + operation.codec_specific_configuration, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: + for ase_id, *args in zip( + operation.ase_id, + operation.cig_id, + operation.cis_id, + operation.sdu_interval, + operation.framing, + operation.phy, + operation.max_sdu, + operation.retransmission_number, + operation.max_transport_latency, + operation.presentation_delay, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.ENABLE, + ASE_Operation.Opcode.UPDATE_METADATA, + ): + for ase_id, *args in zip( + operation.ase_id, + operation.metadata, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.RECEIVER_START_READY, + ASE_Operation.Opcode.DISABLE, + ASE_Operation.Opcode.RECEIVER_STOP_READY, + ASE_Operation.Opcode.RELEASE, + ): + for ase_id in operation.ase_id: + responses.append(self.on_operation(operation.op_code, ase_id, [])) + + control_point_notification = bytes( + [operation.op_code, len(responses)] + ) + b''.join(map(bytes, responses)) + self.device.abort_on( + 'flush', + self.device.notify_subscribers( + self.ase_control_point, control_point_notification + ), + ) + + for ase_id, *_ in responses: + if ase := self.ase_state_machines.get(ase_id): + self.device.abort_on( + 'flush', + self.device.notify_subscribers(ase, ase.value), + ) + + # ----------------------------------------------------------------------------- # Client # ----------------------------------------------------------------------------- @@ -494,3 +1224,24 @@ class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy): gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC ): self.source_audio_locations = characteristics[0] + + +class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = AudioStreamControlService + + sink_ase: List[gatt_client.CharacteristicProxy] + source_ase: List[gatt_client.CharacteristicProxy] + ase_control_point: gatt_client.CharacteristicProxy + + def __init__(self, service_proxy: gatt_client.ServiceProxy): + self.service_proxy = service_proxy + + self.sink_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SINK_ASE_CHARACTERISTIC + ) + self.source_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + self.ase_control_point = service_proxy.get_characteristics_by_uuid( + gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC + )[0] diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index bc0766b2..065e6964 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from typing import Optional from .common import Transport, AsyncPipeSink, SnoopingTransport from ..snoop import create_snooper @@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport: async def open_transport(name: str) -> Transport: """ Open a transport by name. - The name must be : - Where depend on the type (and may be empty for some types). + The name must be : + Where depend on the type (and may be empty for some types), and + is either omitted, or a ,-separated list of = pairs, + enclosed in []. + If there are not metadata or parameter, the : after the may be omitted. + Examples: + * usb:0 + * usb:[driver=rtk]0 + * android-netsim + The supported types are: * serial * udp @@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport: * android-netsim """ - return _wrap_transport(await _open_transport(name)) + scheme, *tail = name.split(':', 1) + spec = tail[0] if tail else None + if spec: + # Metadata may precede the spec + if spec.startswith('['): + metadata_str, *tail = spec[1:].split(']') + spec = tail[0] if tail else None + metadata = dict([entry.split('=') for entry in metadata_str.split(',')]) + else: + metadata = None + + transport = await _open_transport(scheme, spec) + if metadata: + transport.source.metadata = { # type: ignore[attr-defined] + **metadata, + **getattr(transport.source, 'metadata', {}), + } + # pylint: disable=line-too-long + logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined] + + return _wrap_transport(transport) # ----------------------------------------------------------------------------- -async def _open_transport(name: str) -> Transport: +async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: # pylint: disable=import-outside-toplevel # pylint: disable=too-many-return-statements - scheme, *spec = name.split(':', 1) if scheme == 'serial' and spec: from .serial import open_serial_transport - return await open_serial_transport(spec[0]) + return await open_serial_transport(spec) if scheme == 'udp' and spec: from .udp import open_udp_transport - return await open_udp_transport(spec[0]) + return await open_udp_transport(spec) if scheme == 'tcp-client' and spec: from .tcp_client import open_tcp_client_transport - return await open_tcp_client_transport(spec[0]) + return await open_tcp_client_transport(spec) if scheme == 'tcp-server' and spec: from .tcp_server import open_tcp_server_transport - return await open_tcp_server_transport(spec[0]) + return await open_tcp_server_transport(spec) if scheme == 'ws-client' and spec: from .ws_client import open_ws_client_transport - return await open_ws_client_transport(spec[0]) + return await open_ws_client_transport(spec) if scheme == 'ws-server' and spec: from .ws_server import open_ws_server_transport - return await open_ws_server_transport(spec[0]) + return await open_ws_server_transport(spec) if scheme == 'pty': from .pty import open_pty_transport - return await open_pty_transport(spec[0] if spec else None) + return await open_pty_transport(spec) if scheme == 'file': from .file import open_file_transport assert spec is not None - return await open_file_transport(spec[0]) + return await open_file_transport(spec) if scheme == 'vhci': from .vhci import open_vhci_transport - return await open_vhci_transport(spec[0] if spec else None) + return await open_vhci_transport(spec) if scheme == 'hci-socket': from .hci_socket import open_hci_socket_transport - return await open_hci_socket_transport(spec[0] if spec else None) + return await open_hci_socket_transport(spec) if scheme == 'usb': from .usb import open_usb_transport - assert spec is not None - return await open_usb_transport(spec[0]) + assert spec + return await open_usb_transport(spec) if scheme == 'pyusb': from .pyusb import open_pyusb_transport - assert spec is not None - return await open_pyusb_transport(spec[0]) + assert spec + return await open_pyusb_transport(spec) if scheme == 'android-emulator': from .android_emulator import open_android_emulator_transport - return await open_android_emulator_transport(spec[0] if spec else None) + return await open_android_emulator_transport(spec) if scheme == 'android-netsim': from .android_netsim import open_android_netsim_transport - return await open_android_netsim_transport(spec[0] if spec else None) + return await open_android_netsim_transport(spec) raise ValueError('unknown transport scheme') diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 8d19a9e2..9cd7ec21 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: mode = 'host' server_host = 'localhost' server_port = '8554' - if spec is not None: + if spec: params = spec.split(',') for param in params: if param.startswith('mode='): diff --git a/bumble/transport/common.py b/bumble/transport/common.py index ace04da5..f767f54f 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -21,7 +21,7 @@ import struct import asyncio import logging import io -from typing import ContextManager, Tuple, Optional, Protocol, Dict +from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from bumble import hci from bumble.colors import color diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index df9e885a..41250433 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport: ) from error # Compute the adapter index - if spec is None: - adapter_index = 0 - else: - adapter_index = int(spec) + adapter_index = int(spec) if spec else 0 # Bind the socket # NOTE: since Python doesn't support binding with the required address format (yet), diff --git a/docs/mkdocs/src/drivers/index.md b/docs/mkdocs/src/drivers/index.md index a904e006..cb0a981e 100644 --- a/docs/mkdocs/src/drivers/index.md +++ b/docs/mkdocs/src/drivers/index.md @@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly. This may include, for instance, loading a Firmware image or patch, loading a configuration. +By default, drivers will be automatically probed to determine if they should be +used with particular HCI controller. +When the transport for an HCI controller is instantiated from a transport name, +a driver may also be forced by specifying ``driver=`` in the optional +metadata portion of the transport name. For example, +``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the +first USB device, even if a normal probe would not have selected it based on the +USB vendor ID and product ID. + Drivers included in the module are: * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. \ No newline at end of file diff --git a/docs/mkdocs/src/drivers/realtek.md b/docs/mkdocs/src/drivers/realtek.md index acbce490..599ce048 100644 --- a/docs/mkdocs/src/drivers/realtek.md +++ b/docs/mkdocs/src/drivers/realtek.md @@ -1,13 +1,16 @@ REALTEK DRIVER ============== -This driver supports loading firmware images and optional config data to +This driver supports loading firmware images and optional config data to USB dongles with a Realtek chipset. A number of USB dongles are supported, but likely not all. -When using a USB dongle, the USB product ID and manufacturer ID are used +When using a USB dongle, the USB product ID and vendor ID are used to find whether a matching set of firmware image and config data is needed for that specific model. If a match exists, the driver will try load the firmware image and, if needed, config data. +Alternatively, the metadata property ``driver=rtk`` may be specified in a transport +name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just +``usb:0`` for the first USB device). The driver will look for those files by name, in order, in: * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR` diff --git a/examples/leaudio.json b/examples/leaudio.json index 4b6edfce..c4c5a11c 100644 --- a/examples/leaudio.json +++ b/examples/leaudio.json @@ -1,5 +1,6 @@ { "name": "Bumble-LEA", "keystore": "JsonKeyStore", + "address": "F0:F1:F2:F3:F4:FA", "advertising_interval": 100 } diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 868b4f84..e71cbeff 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -19,12 +19,14 @@ import asyncio import logging import sys import os +import struct from bumble.core import AdvertisingData -from bumble.device import Device +from bumble.device import Device, CisLink from bumble.hci import ( CodecID, CodingFormat, OwnAddressType, + HCI_IsoDataPacket, HCI_LE_Set_Extended_Advertising_Parameters_Command, ) from bumble.profiles.bap import ( @@ -35,6 +37,7 @@ from bumble.profiles.bap import ( SupportedFrameDuration, PacRecord, PublishedAudioCapabilitiesService, + AudioStreamControlService, ) from bumble.transport import open_transport_or_link @@ -103,6 +106,8 @@ async def main() -> None: ) ) + device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) + advertising_data = bytes( AdvertisingData( [ @@ -110,6 +115,16 @@ async def main() -> None: AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble LE Audio', 'utf-8'), ), + ( + AdvertisingData.FLAGS, + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), + ), ( AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(PublishedAudioCapabilitiesService.UUID), @@ -117,6 +132,41 @@ async def main() -> None: ] ) ) + subprocess = await asyncio.create_subprocess_shell( + f'dlc3 | ffplay pipe:0', + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdin = subprocess.stdin + assert stdin + + # Write a fake LC3 header to dlc3. + stdin.write( + bytes([0x1C, 0xCC]) # Header. + + struct.pack( + '= Build.VERSION_CODES.TIRAMISU) { + bluetoothAdapter.getRemoteLeDevice( + address, + if (addressIsPublic) { + BluetoothDevice.ADDRESS_TYPE_PUBLIC + } else { + BluetoothDevice.ADDRESS_TYPE_RANDOM + } + ) + } else { + bluetoothAdapter.getRemoteDevice(address) + } + + val gatt = remoteDevice.connectGatt( + context, + false, + object : BluetoothGattCallback() { + override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) { + Log.info("MTU update: mtu=$mtu status=$status") + viewModel.mtu = mtu + } + + override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) { + Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status") + viewModel.txPhy = txPhy + viewModel.rxPhy = rxPhy + } + + override fun onConnectionStateChange( + gatt: BluetoothGatt?, status: Int, newState: Int + ) { + if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { + gatt.setPreferredPhy( + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_LE_2M_MASK, + BluetoothDevice.PHY_OPTION_NO_PREFERRED + ) + gatt.readPhy() + } + } + }, + BluetoothDevice.TRANSPORT_LE, + if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK + ) + val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm) val client = SocketClient(viewModel, socket) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt index 79c70045..76c297b3 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/L2capServer.kt @@ -30,7 +30,7 @@ private val Log = Logger.getLogger("btbench.l2cap-server") class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdapter: BluetoothAdapter) { @SuppressLint("MissingPermission") fun run() { - // Advertise to that the peer can find us and connect. + // Advertise so that the peer can find us and connect. val callback = object: AdvertiseCallback() { override fun onStartFailure(errorCode: Int) { Log.warning("failed to start advertising: $errorCode") @@ -50,13 +50,12 @@ class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdap val advertiseData = AdvertiseData.Builder().build() val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build() val advertiser = bluetoothAdapter.bluetoothLeAdvertiser - advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel() viewModel.l2capPsm = serverSocket.psm Log.info("psm = $serverSocket.psm") val server = SocketServer(viewModel, serverSocket) - server.run({ advertiser.stopAdvertising(callback) }) + server.run({ advertiser.stopAdvertising(callback) }, { advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) }) } } \ No newline at end of file diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt index 314f7465..60818371 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/MainActivity.kt @@ -26,23 +26,33 @@ import android.os.Bundle import androidx.activity.ComponentActivity import androidx.activity.compose.setContent import androidx.activity.result.contract.ActivityResultContracts +import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Button import androidx.compose.material3.Divider import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Slider import androidx.compose.material3.Surface +import androidx.compose.material3.Switch import androidx.compose.material3.Text import androidx.compose.material3.TextField import androidx.compose.runtime.Composable +import androidx.compose.runtime.remember +import androidx.compose.ui.Alignment import androidx.compose.ui.ExperimentalComposeUiApi import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.FocusRequester +import androidx.compose.ui.focus.focusRequester +import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalSoftwareKeyboardController import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.input.ImeAction @@ -171,7 +181,7 @@ class MainActivity : ComponentActivity() { } private fun runL2capClient() { - val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it) } + val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it, baseContext) } l2capClient?.run() } @@ -199,9 +209,12 @@ fun MainView( runL2capServer: () -> Unit ) { BTBenchTheme { - // A surface container using the 'background' color from the theme + val scrollState = rememberScrollState() Surface( - modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState), + color = MaterialTheme.colorScheme.background ) { Column(modifier = Modifier.padding(horizontal = 16.dp)) { Text( @@ -212,28 +225,33 @@ fun MainView( ) Divider() val keyboardController = LocalSoftwareKeyboardController.current - TextField(label = { - Text(text = "Peer Bluetooth Address") - }, + val focusRequester = remember { FocusRequester() } + val focusManager = LocalFocusManager.current + TextField( + label = { + Text(text = "Peer Bluetooth Address") + }, value = appViewModel.peerBluetoothAddress, - modifier = Modifier.fillMaxWidth(), + modifier = Modifier.fillMaxWidth().focusRequester(focusRequester), keyboardOptions = KeyboardOptions.Default.copy( keyboardType = KeyboardType.Ascii, imeAction = ImeAction.Done ), onValueChange = { appViewModel.updatePeerBluetoothAddress(it) }, - keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() }) + keyboardActions = KeyboardActions(onDone = { + keyboardController?.hide() + focusManager.clearFocus() + }) ) Divider() TextField(label = { Text(text = "L2CAP PSM") }, value = appViewModel.l2capPsm.toString(), - modifier = Modifier.fillMaxWidth(), + modifier = Modifier.fillMaxWidth().focusRequester(focusRequester), keyboardOptions = KeyboardOptions.Default.copy( - keyboardType = KeyboardType.Number, - imeAction = ImeAction.Done + keyboardType = KeyboardType.Number, imeAction = ImeAction.Done ), onValueChange = { if (it.isNotEmpty()) { @@ -243,7 +261,11 @@ fun MainView( } } }, - keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() })) + keyboardActions = KeyboardActions(onDone = { + keyboardController?.hide() + focusManager.clearFocus() + }) + ) Divider() Slider( value = appViewModel.senderPacketCountSlider, onValueChange = { @@ -264,7 +286,19 @@ fun MainView( ActionButton( text = "Become Discoverable", onClick = becomeDiscoverable, true ) - Row() { + Row( + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text(text = "2M PHY") + Spacer(modifier = Modifier.padding(start = 8.dp)) + Switch( + checked = appViewModel.use2mPhy, + onCheckedChange = { appViewModel.use2mPhy = it } + ) + + } + Row { ActionButton( text = "RFCOMM Client", onClick = runRfcommClient, !appViewModel.running ) @@ -272,7 +306,7 @@ fun MainView( text = "RFCOMM Server", onClick = runRfcommServer, !appViewModel.running ) } - Row() { + Row { ActionButton( text = "L2CAP Client", onClick = runL2capClient, !appViewModel.running ) @@ -281,6 +315,12 @@ fun MainView( ) } Divider() + Text( + text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else "" + ) + Text( + text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else "" + ) Text( text = "Packets Sent: ${appViewModel.packetsSent}" ) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt index 93755e40..35ee8da5 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/Model.kt @@ -32,6 +32,10 @@ class AppViewModel : ViewModel() { private var preferences: SharedPreferences? = null var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS) var l2capPsm by mutableStateOf(0) + var use2mPhy by mutableStateOf(true) + var mtu by mutableStateOf(0) + var rxPhy by mutableStateOf(0) + var txPhy by mutableStateOf(0) var senderPacketCountSlider by mutableFloatStateOf(0.0F) var senderPacketSizeSlider by mutableFloatStateOf(0.0F) var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT) @@ -64,11 +68,12 @@ class AppViewModel : ViewModel() { } fun updatePeerBluetoothAddress(peerBluetoothAddress: String) { - this.peerBluetoothAddress = peerBluetoothAddress + val address = peerBluetoothAddress.uppercase() + this.peerBluetoothAddress = address // Save the address to the preferences with(preferences!!.edit()) { - putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, peerBluetoothAddress) + putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, address) apply() } } @@ -116,7 +121,7 @@ class AppViewModel : ViewModel() { } fun updateSenderPacketSizeSlider() { - if (senderPacketSize <= 1) { + if (senderPacketSize <= 16) { senderPacketSizeSlider = 0.0F } else if (senderPacketSize <= 256) { senderPacketSizeSlider = 0.02F @@ -138,7 +143,7 @@ class AppViewModel : ViewModel() { fun updateSenderPacketSize() { if (senderPacketSizeSlider < 0.1F) { - senderPacketSize = 1 + senderPacketSize = 16 } else if (senderPacketSizeSlider < 0.3F) { senderPacketSize = 256 } else if (senderPacketSizeSlider < 0.5F) { diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt index 644a5bda..e976c429 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommClient.kt @@ -25,7 +25,8 @@ private val Log = Logger.getLogger("btbench.rfcomm-client") class RfcommClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) { @SuppressLint("MissingPermission") fun run() { - val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress) + val address = viewModel.peerBluetoothAddress.take(17) + val remoteDevice = bluetoothAdapter.getRemoteDevice(address) val socket = remoteDevice.createInsecureRfcommSocketToServiceRecord( DEFAULT_RFCOMM_UUID ) diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt index f06736b2..69612c55 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/RfcommServer.kt @@ -30,6 +30,6 @@ class RfcommServer(private val viewModel: AppViewModel, val bluetoothAdapter: Bl ) val server = SocketServer(viewModel, serverSocket) - server.run({}) + server.run({}, {}) } } \ No newline at end of file diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt index cc5058e4..bd5b7f4a 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketClient.kt @@ -22,6 +22,8 @@ import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.socket-client") +private const val DEFAULT_STARTUP_DELAY = 3000 + class SocketClient(private val viewModel: AppViewModel, private val socket: BluetoothSocket) { @SuppressLint("MissingPermission") fun run() { @@ -56,6 +58,10 @@ class SocketClient(private val viewModel: AppViewModel, private val socket: Blue socketDataSource.receive() } + Log.info("Startup delay: $DEFAULT_STARTUP_DELAY") + Thread.sleep(DEFAULT_STARTUP_DELAY.toLong()); + Log.info("Starting to send") + sender.run() cleanup() } diff --git a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt index 3f9c3e1d..e83a47f2 100644 --- a/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt +++ b/extras/android/BtBench/app/src/main/java/com/github/google/bumble/btbench/SocketServer.kt @@ -22,14 +22,13 @@ import kotlin.concurrent.thread private val Log = Logger.getLogger("btbench.socket-server") class SocketServer(private val viewModel: AppViewModel, private val serverSocket: BluetoothServerSocket) { - fun run(onTerminate: () -> Unit) { + fun run(onConnected: () -> Unit, onDisconnected: () -> Unit) { var aborted = false viewModel.running = true fun cleanup() { serverSocket.close() viewModel.running = false - onTerminate() } thread(name = "SocketServer") { @@ -38,6 +37,7 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket serverSocket.close() } Log.info("waiting for connection...") + onDisconnected() val socket = try { serverSocket.accept() } catch (error: IOException) { @@ -45,7 +45,8 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket cleanup() return@thread } - Log.info("got connection") + Log.info("got connection from ${socket.remoteDevice.address}") + onConnected() viewModel.aborter = { aborted = true diff --git a/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt b/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt index ebdc708a..493b7e50 100644 --- a/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt +++ b/extras/android/RemoteHCI/app/src/main/java/com/github/google/bumble/remotehci/MainActivity.kt @@ -10,8 +10,10 @@ import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Button import androidx.compose.material3.Divider import androidx.compose.material3.ExperimentalMaterial3Api @@ -71,7 +73,7 @@ class AppViewModel : ViewModel(), HciProxy.Listener { this.tcpPort = tcpPort // Save the port to the preferences - with (preferences!!.edit()) { + with(preferences!!.edit()) { putString(TCP_PORT_PREF_KEY, tcpPort.toString()) apply() } @@ -138,7 +140,8 @@ class MainActivity : ComponentActivity() { log.warning("Exception while running HCI Server: $error") } catch (error: HalException) { log.warning("HAL exception: ${error.message}") - appViewModel.message = "Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell." + appViewModel.message = + "Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell." } log.info("HCI Proxy thread ended") appViewModel.canStart = true @@ -157,9 +160,12 @@ fun ActionButton(text: String, onClick: () -> Unit, enabled: Boolean) { @Composable fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { RemoteHCITheme { - // A surface container using the 'background' color from the theme + val scrollState = rememberScrollState() Surface( - modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState), + color = MaterialTheme.colorScheme.background ) { Column(modifier = Modifier.padding(horizontal = 16.dp)) { Text( @@ -174,13 +180,15 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { ) Divider() val keyboardController = LocalSoftwareKeyboardController.current - TextField( - label = { - Text(text = "TCP Port") - }, + TextField(label = { + Text(text = "TCP Port") + }, value = appViewModel.tcpPort.toString(), modifier = Modifier.fillMaxWidth(), - keyboardOptions = KeyboardOptions.Default.copy(keyboardType = KeyboardType.Number, imeAction = ImeAction.Done), + keyboardOptions = KeyboardOptions.Default.copy( + keyboardType = KeyboardType.Number, + imeAction = ImeAction.Done + ), onValueChange = { if (it.isNotEmpty()) { val tcpPort = it.toIntOrNull() @@ -189,10 +197,7 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) { } } }, - keyboardActions = KeyboardActions( - onDone = {keyboardController?.hide()} - ) - ) + keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() })) Divider() val connectState = if (appViewModel.hostConnected) "CONNECTED" else "DISCONNECTED" Text( diff --git a/tests/bap_test.py b/tests/bap_test.py index 01fc568e..d9d12596 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- import asyncio import os +import functools import pytest import logging @@ -24,11 +25,26 @@ from bumble import device from bumble.hci import CodecID, CodingFormat from bumble.profiles.bap import ( AudioLocation, + AseStateMachine, + ASE_Operation, + ASE_Config_Codec, + ASE_Config_QOS, + ASE_Disable, + ASE_Enable, + ASE_Receiver_Start_Ready, + ASE_Receiver_Stop_Ready, + ASE_Release, + ASE_Update_Metadata, SupportedFrameDuration, SupportedSamplingFrequency, + SamplingFrequency, + FrameDuration, CodecSpecificCapabilities, + CodecSpecificConfiguration, ContextType, PacRecord, + AudioStreamControlService, + AudioStreamControlServiceProxy, PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesServiceProxy, ) @@ -40,6 +56,13 @@ from .test_utils import TwoDevices logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +def basic_check(operation: ASE_Operation): + serialized = bytes(operation) + parsed = ASE_Operation.from_bytes(serialized) + assert bytes(parsed) == serialized + + # ----------------------------------------------------------------------------- def test_codec_specific_capabilities() -> None: SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000 @@ -85,6 +108,92 @@ def test_vendor_specific_pac_record() -> None: assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA +# ----------------------------------------------------------------------------- +def test_ASE_Config_Codec() -> None: + operation = ASE_Config_Codec( + ase_id=[1, 2], + target_latency=[3, 4], + target_phy=[5, 6], + codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)], + codec_specific_configuration=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Config_QOS() -> None: + operation = ASE_Config_QOS( + ase_id=[1, 2], + cig_id=[1, 2], + cis_id=[3, 4], + sdu_interval=[5, 6], + framing=[0, 1], + phy=[2, 3], + max_sdu=[4, 5], + retransmission_number=[6, 7], + max_transport_latency=[8, 9], + presentation_delay=[10, 11], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Enable() -> None: + operation = ASE_Enable( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Update_Metadata() -> None: + operation = ASE_Update_Metadata( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Disable() -> None: + operation = ASE_Disable(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Release() -> None: + operation = ASE_Release(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Receiver_Start_Ready() -> None: + operation = ASE_Receiver_Start_Ready(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_ASE_Receiver_Stop_Ready() -> None: + operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2]) + basic_check(operation) + + +# ----------------------------------------------------------------------------- +def test_codec_specific_configuration() -> None: + SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000 + FRAME_SURATION = FrameDuration.DURATION_10000_US + AUDIO_LOCATION = AudioLocation.FRONT_LEFT + config = CodecSpecificConfiguration( + sampling_frequency=SAMPLE_FREQUENCY, + frame_duration=FRAME_SURATION, + audio_channel_allocation=AUDIO_LOCATION, + octets_per_codec_frame=60, + codec_frames_per_sdu=1, + ) + assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_pacs(): @@ -140,6 +249,148 @@ async def test_pacs(): ) +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_ascs(): + devices = TwoDevices() + devices[0].add_service( + AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2]) + ) + + await devices.setup_connection() + peer = device.Peer(devices.connections[1]) + ascs_client = await peer.discover_service_and_create_proxy( + AudioStreamControlServiceProxy + ) + + notifications = {1: asyncio.Queue(), 2: asyncio.Queue()} + + def on_notification(data: bytes, ase_id: int): + notifications[ase_id].put_nowait(data) + + # Should be idle + assert await ascs_client.sink_ase[0].read_value() == bytes( + [1, AseStateMachine.State.IDLE] + ) + assert await ascs_client.sink_ase[1].read_value() == bytes( + [2, AseStateMachine.State.IDLE] + ) + + # Subscribe + await ascs_client.sink_ase[0].subscribe( + functools.partial(on_notification, ase_id=1) + ) + await ascs_client.sink_ase[1].subscribe( + functools.partial(on_notification, ase_id=2) + ) + + # Config Codec + config = CodecSpecificConfiguration( + sampling_frequency=SamplingFrequency.FREQ_48000, + frame_duration=FrameDuration.DURATION_10000_US, + audio_channel_allocation=AudioLocation.FRONT_LEFT, + octets_per_codec_frame=120, + codec_frames_per_sdu=1, + ) + await ascs_client.ase_control_point.write_value( + ASE_Config_Codec( + ase_id=[1, 2], + target_latency=[3, 4], + target_phy=[5, 6], + codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)], + codec_specific_configuration=[config, config], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.CODEC_CONFIGURED] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.CODEC_CONFIGURED] + ) + + # Config QOS + await ascs_client.ase_control_point.write_value( + ASE_Config_QOS( + ase_id=[1, 2], + cig_id=[1, 2], + cis_id=[3, 4], + sdu_interval=[5, 6], + framing=[0, 1], + phy=[2, 3], + max_sdu=[4, 5], + retransmission_number=[6, 7], + max_transport_latency=[8, 9], + presentation_delay=[10, 11], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.QOS_CONFIGURED] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.QOS_CONFIGURED] + ) + + # Enable + await ascs_client.ase_control_point.write_value( + ASE_Enable( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.ENABLING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.ENABLING] + ) + + # CIS establishment + devices[0].emit( + 'cis_establishment', + device.CisLink( + device=devices[0], + acl_connection=devices.connections[0], + handle=5, + cis_id=3, + cig_id=1, + ), + ) + devices[0].emit( + 'cis_establishment', + device.CisLink( + device=devices[0], + acl_connection=devices.connections[0], + handle=6, + cis_id=4, + cig_id=2, + ), + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.STREAMING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.STREAMING] + ) + + # Release + await ascs_client.ase_control_point.write_value( + ASE_Release( + ase_id=[1, 2], + metadata=[b'foo', b'bar'], + ) + ) + assert (await notifications[1].get())[:2] == bytes( + [1, AseStateMachine.State.RELEASING] + ) + assert (await notifications[2].get())[:2] == bytes( + [2, AseStateMachine.State.RELEASING] + ) + assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE]) + assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE]) + + await asyncio.sleep(0.001) + + # ----------------------------------------------------------------------------- async def run(): await test_pacs() diff --git a/tests/device_test.py b/tests/device_test.py index 1bcd0d08..d51431f5 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -20,8 +20,14 @@ import logging import os from types import LambdaType import pytest +from unittest import mock -from bumble.core import BT_BR_EDR_TRANSPORT +from bumble.core import ( + BT_BR_EDR_TRANSPORT, + BT_LE_TRANSPORT, + BT_PERIPHERAL_ROLE, + ConnectionParameters, +) from bumble.device import Connection, Device from bumble.host import Host from bumble.hci import ( @@ -30,6 +36,7 @@ from bumble.hci import ( HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS, Address, + OwnAddressType, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, @@ -232,6 +239,172 @@ async def test_flush(): pass +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_legacy_advertising(): + device = Device(host=mock.AsyncMock(Host)) + + # Start advertising + advertiser = await device.start_legacy_advertising() + assert device.legacy_advertiser + + # Stop advertising + await advertiser.stop() + assert not device.legacy_advertiser + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'own_address_type,', + (OwnAddressType.PUBLIC, OwnAddressType.RANDOM), +) +@pytest.mark.asyncio +async def test_legacy_advertising_connection(own_address_type): + device = Device(host=mock.AsyncMock(Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + + # Start advertising + advertiser = await device.start_legacy_advertising() + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + + if own_address_type == OwnAddressType.PUBLIC: + assert device.lookup_connection(0x0001).self_address == device.public_address + else: + assert device.lookup_connection(0x0001).self_address == device.random_address + + # For unknown reason, read_phy() in on_connection() would be killed at the end of + # test, so we force scheduling here to avoid an warning. + await asyncio.sleep(0.0001) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'auto_restart,', + (True, False), +) +@pytest.mark.asyncio +async def test_legacy_advertising_disconnection(auto_restart): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_legacy_advertising(auto_restart=auto_restart) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + + device.start_legacy_advertising = mock.AsyncMock() + + device.on_disconnection(0x0001, 0) + + if auto_restart: + device.start_legacy_advertising.assert_called_with( + advertising_type=advertiser.advertising_type, + own_address_type=advertiser.own_address_type, + auto_restart=advertiser.auto_restart, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + ) + else: + device.start_legacy_advertising.assert_not_called() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_extended_advertising(): + device = Device(host=mock.AsyncMock(Host)) + + # Start advertising + advertiser = await device.start_extended_advertising() + assert device.extended_advertisers + + # Stop advertising + await advertiser.stop() + assert not device.extended_advertisers + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'own_address_type,', + (OwnAddressType.PUBLIC, OwnAddressType.RANDOM), +) +@pytest.mark.asyncio +async def test_extended_advertising_connection(own_address_type): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_extended_advertising( + own_address_type=own_address_type + ) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + device.on_advertising_set_termination( + HCI_SUCCESS, + advertiser.handle, + 0x0001, + ) + + if own_address_type == OwnAddressType.PUBLIC: + assert device.lookup_connection(0x0001).self_address == device.public_address + else: + assert device.lookup_connection(0x0001).self_address == device.random_address + + # For unknown reason, read_phy() in on_connection() would be killed at the end of + # test, so we force scheduling here to avoid an warning. + await asyncio.sleep(0.0001) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'auto_restart,', + (True, False), +) +@pytest.mark.asyncio +async def test_extended_advertising_disconnection(auto_restart): + device = Device(host=mock.AsyncMock(spec=Host)) + peer_address = Address('F0:F1:F2:F3:F4:F5') + advertiser = await device.start_extended_advertising(auto_restart=auto_restart) + device.on_connection( + 0x0001, + BT_LE_TRANSPORT, + peer_address, + BT_PERIPHERAL_ROLE, + ConnectionParameters(0, 0, 0), + ) + device.on_advertising_set_termination( + HCI_SUCCESS, + advertiser.handle, + 0x0001, + ) + + device.start_extended_advertising = mock.AsyncMock() + + device.on_disconnection(0x0001, 0) + + if auto_restart: + device.start_extended_advertising.assert_called_with( + advertising_properties=advertiser.advertising_properties, + own_address_type=advertiser.own_address_type, + auto_restart=advertiser.auto_restart, + advertising_data=advertiser.advertising_data, + scan_response_data=advertiser.scan_response_data, + ) + else: + device.start_extended_advertising.assert_not_called() + + # ----------------------------------------------------------------------------- def test_gatt_services_with_gas(): device = Device(host=Host(None, None))