Compare commits

..

26 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod b4af46ebd5 use TCP_NODELAY on socket 2023-12-27 12:11:20 -08:00
Gilles Boccon-Gibod c08da3193e format 2023-12-27 11:56:06 -08:00
Gilles Boccon-Gibod fd4d68e5c0 print controller flow control info 2023-12-26 13:24:24 -08:00
Gilles Boccon-Gibod b90d0f8710 fix tests 2023-12-26 09:09:20 -08:00
Gilles Boccon-Gibod afc6d19e04 address PR comments 2023-12-23 14:21:44 -08:00
Gilles Boccon-Gibod c05f073b33 Update bumble/host.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2023-12-23 14:15:53 -08:00
Gilles Boccon-Gibod 2b4c2a22f4 format 2023-12-22 14:22:08 -08:00
Gilles Boccon-Gibod 47fe93a148 support per-transport ACL queues 2023-12-22 13:52:33 -08:00
Gilles Boccon-Gibod a286700239 Merge pull request #368 from google/gbg/driver-load-before-reset
support drivers that can't use reset directly.
2023-12-11 18:06:23 -08:00
Gilles Boccon-Gibod 98ed772e8a address PR comments and add some typing 2023-12-11 17:52:04 -08:00
Gilles Boccon-Gibod f0b55a4f97 Merge pull request #367 from google/gbg/android-bench-update
Android bench app: add support for 2M phy
2023-12-11 10:20:56 -08:00
zxzxwu b74503d345 Merge pull request #359 from zxzxwu/ascs
Audio Stream Control Service
2023-12-12 00:47:03 +08:00
Josh Wu f911163e49 Improve ASCS logging 2023-12-12 00:36:24 +08:00
Gilles Boccon-Gibod b083cc99ad fix spec parsing 2023-12-08 18:57:02 -08:00
Gilles Boccon-Gibod 62a8ced447 support drivers that can't use reset directly. 2023-12-08 17:28:57 -08:00
Josh Wu 81a6b1e097 Replace 3.9 dict merger 2023-12-08 11:10:17 +08:00
Josh Wu dd090c9e6b Add ASCS tests 2023-12-08 11:00:44 +08:00
Josh Wu 11faa48422 Fix ASE state change 2023-12-08 09:53:14 +08:00
Josh Wu 55596176c2 ffplay routing 2023-12-08 09:53:14 +08:00
Josh Wu 4d6822d312 Remove ISO data path on release 2023-12-08 09:53:14 +08:00
Josh Wu 985c365e6d Setup data path after CIS established 2023-12-08 09:53:14 +08:00
Josh Wu af57762227 Parse CodecSpecificConfiguration 2023-12-08 09:53:14 +08:00
Josh Wu 3575f9030e Add Audio Stream Control Service 2023-12-08 09:53:14 +08:00
zxzxwu 698d947d85 Merge pull request #366 from zxzxwu/extadv
Add advertiser classes and handle adv set terminated events
2023-12-08 09:52:42 +08:00
Josh Wu ff6528d2bf Add Advertising unit tests 2023-12-08 01:38:01 +08:00
Josh Wu 72ac75a98d Add advertiser classes and handle adv set terminated events
* Convert hci.OwnAddressType to enum
* Add LegacyAdvertiser and ExtendedAdvertiser classes
* Rename start/stop_advertising() => start/stop_legacy_advertising()
* Handle HCI_Advertising_Set_Terminated
* Properly restart advertisement on disconnection
2023-12-07 15:51:51 +08:00
25 changed files with 1851 additions and 250 deletions
+21 -2
View File
@@ -82,10 +82,11 @@ SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1022
DEFAULT_L2CAP_MPS = 1024
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1022
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
DEFAULT_RFCOMM_CHANNEL = 8
@@ -952,6 +953,10 @@ class Central(Connection.Listener):
await self.device.power_on()
if self.classic:
await self.device.set_discoverable(False)
await self.device.set_connectable(False)
print(color(f'### Connecting to {self.peripheral_address}...', 'cyan'))
try:
self.connection = await self.device.connect(
@@ -972,6 +977,11 @@ class Central(Connection.Listener):
self.connection.listener = self
print_connection(self.connection)
# Wait a bit after the connection, some controllers aren't very good when
# we start sending data right away while some connection parameters are
# updated post connection
await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)
# Request a new data length if requested
if self.extended_data_length:
print(color('+++ Requesting extended data length', 'cyan'))
@@ -1098,6 +1108,15 @@ class Peripheral(Device.Listener, Connection.Listener):
self.connection = connection
self.connected.set()
# Stop being discoverable and connectable
if self.classic:
async def stop_being_discoverable_connectable():
await self.device.set_discoverable(False)
await self.device.set_connectable(False)
AsyncRunner.spawn(stop_being_discoverable_connectable())
# Request a new data length if needed
if self.extended_data_length:
print("+++ Requesting extended data length")
+34 -2
View File
@@ -32,10 +32,14 @@ from bumble.hci import (
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
@@ -59,7 +63,7 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_classic_info(host):
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
@@ -80,7 +84,7 @@ async def get_classic_info(host):
# -----------------------------------------------------------------------------
async def get_le_info(host):
async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -136,6 +140,31 @@ async def get_le_info(host):
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(transport):
print('<<< connecting to HCI...')
@@ -168,6 +197,9 @@ async def async_main(transport):
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
+28 -1
View File
@@ -134,12 +134,14 @@ class Controller:
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27
self.hc_total_num_data_packets = 64
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.event_mask = 0
self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_event_mask = 0
@@ -914,6 +916,19 @@ class Controller:
'''
return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
'''
return struct.pack(
'<BHBHH',
HCI_SUCCESS,
self.hc_data_packet_length,
0,
self.hc_total_num_data_packets,
0,
)
def on_hci_read_bd_addr_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
@@ -1263,3 +1278,15 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
def on_hci_le_setup_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
+180 -46
View File
@@ -437,6 +437,38 @@ class AdvertisingType(IntEnum):
)
# -----------------------------------------------------------------------------
@dataclass
class LegacyAdvertiser:
device: Device
advertising_type: AdvertisingType
own_address_type: OwnAddressType
auto_restart: bool
advertising_data: Optional[bytes]
scan_response_data: Optional[bytes]
async def stop(self) -> None:
await self.device.stop_legacy_advertising()
# -----------------------------------------------------------------------------
@dataclass
class ExtendedAdvertiser(CompositeEventEmitter):
device: Device
handle: int
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties
own_address_type: OwnAddressType
auto_restart: bool
advertising_data: Optional[bytes]
scan_response_data: Optional[bytes]
def __post_init__(self) -> None:
super().__init__()
async def stop(self) -> None:
await self.device.stop_extended_advertising(self.handle)
# -----------------------------------------------------------------------------
class LePhyOptions:
# Coded PHY preference
@@ -658,6 +690,9 @@ class Connection(CompositeEventEmitter):
gatt_client: gatt_client.Client
pairing_peer_io_capability: Optional[int]
pairing_peer_authentication_requirements: Optional[int]
advertiser_after_disconnection: Union[
LegacyAdvertiser, ExtendedAdvertiser, None
] = None
@composite_listener
class Listener:
@@ -1063,7 +1098,8 @@ class Device(CompositeEventEmitter):
]
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
config: DeviceConfiguration
extended_advertising_handles: Set[int]
legacy_advertiser: Optional[LegacyAdvertiser]
extended_advertisers: Dict[int, ExtendedAdvertiser]
sco_links: Dict[int, ScoLink]
cis_links: Dict[int, CisLink]
_pending_cis: Dict[int, Tuple[int, int]]
@@ -1141,10 +1177,7 @@ class Device(CompositeEventEmitter):
self._host = None
self.powered_on = False
self.advertising = False
self.advertising_type = None
self.auto_restart_inquiry = True
self.auto_restart_advertising = False
self.command_timeout = 10 # seconds
self.gatt_server = gatt_server.Server(self)
self.sdp_server = sdp.Server(self)
@@ -1168,10 +1201,10 @@ class Device(CompositeEventEmitter):
self.classic_pending_accepts = {
Address.ANY: []
} # Futures, by BD address OR [Futures] for Address.ANY
self.extended_advertising_handles = set()
self.legacy_advertiser = None
self.extended_advertisers = {}
# Own address type cache
self.advertising_own_address_type = None
self.connect_own_address_type = None
# Use the initial config or a default
@@ -1579,6 +1612,7 @@ class Device(CompositeEventEmitter):
return self.host.supports_le_feature(feature_map[phy])
@deprecated("Please use start_legacy_advertising.")
async def start_advertising(
self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
@@ -1586,15 +1620,49 @@ class Device(CompositeEventEmitter):
own_address_type: int = OwnAddressType.RANDOM,
auto_restart: bool = False,
) -> None:
await self.start_legacy_advertising(
advertising_type=advertising_type,
target=target,
own_address_type=OwnAddressType(own_address_type),
auto_restart=auto_restart,
)
async def start_legacy_advertising(
self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target: Optional[Address] = None,
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
auto_restart: bool = False,
advertising_data: Optional[bytes] = None,
scan_response_data: Optional[bytes] = None,
) -> LegacyAdvertiser:
"""Starts an legacy advertisement.
Args:
advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command.
target: Directed advertising target. Directed type should be set in advertising_type arg.
own_address_type: own address type to use in the advertising.
auto_restart: whether the advertisement will be restarted after disconnection.
scan_response_data: raw scan response.
advertising_data: raw advertising data.
Returns:
LegacyAdvertiser object containing the metadata of advertisement.
"""
if self.extended_advertisers:
logger.warning(
'Trying to start Legacy and Extended Advertising at the same time!'
)
# If we're advertising, stop first
if self.advertising:
if self.legacy_advertiser:
await self.stop_advertising()
# Set/update the advertising data if the advertising type allows it
if advertising_type.has_data:
await self.send_command(
HCI_LE_Set_Advertising_Data_Command(
advertising_data=self.advertising_data
advertising_data=advertising_data or self.advertising_data or b''
),
check_result=True,
)
@@ -1603,7 +1671,9 @@ class Device(CompositeEventEmitter):
if advertising_type.is_scannable:
await self.send_command(
HCI_LE_Set_Scan_Response_Data_Command(
scan_response_data=self.scan_response_data
scan_response_data=scan_response_data
or self.scan_response_data
or b''
),
check_result=True,
)
@@ -1640,45 +1710,57 @@ class Device(CompositeEventEmitter):
check_result=True,
)
self.advertising_type = advertising_type
self.advertising_own_address_type = own_address_type
self.advertising = True
self.auto_restart_advertising = auto_restart
self.legacy_advertiser = LegacyAdvertiser(
device=self,
advertising_type=advertising_type,
own_address_type=own_address_type,
auto_restart=auto_restart,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
return self.legacy_advertiser
@deprecated("Please use stop_legacy_advertising.")
async def stop_advertising(self) -> None:
await self.stop_legacy_advertising()
async def stop_legacy_advertising(self) -> None:
# Disable advertising
if self.advertising:
if self.legacy_advertiser:
await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
check_result=True,
)
self.advertising_type = None
self.advertising_own_address_type = None
self.advertising = False
self.auto_restart_advertising = False
self.legacy_advertiser = None
@experimental('Extended Advertising is still experimental - Might be changed soon.')
async def start_extended_advertising(
self,
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING,
target: Address = Address.ANY,
own_address_type: int = OwnAddressType.RANDOM,
scan_response: Optional[bytes] = None,
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
auto_restart: bool = True,
advertising_data: Optional[bytes] = None,
) -> int:
scan_response_data: Optional[bytes] = None,
) -> ExtendedAdvertiser:
"""Starts an extended advertising set.
Args:
advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command
target: Directed advertising target. Directed property should be set in advertising_properties arg.
own_address_type: own address type to use in the advertising.
scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
auto_restart: whether the advertisement will be restarted after disconnection.
advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent.
scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
Returns:
Handle of the new advertising set.
ExtendedAdvertiser object containing the metadata of advertisement.
"""
if self.legacy_advertiser:
logger.warning(
'Trying to start Legacy and Extended Advertising at the same time!'
)
adv_handle = -1
# Find a free handle
@@ -1686,7 +1768,7 @@ class Device(CompositeEventEmitter):
DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE,
DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1,
):
if i not in self.extended_advertising_handles:
if i not in self.extended_advertisers:
adv_handle = i
break
@@ -1733,13 +1815,13 @@ class Device(CompositeEventEmitter):
)
# Set the scan response if present
if scan_response is not None:
if scan_response_data is not None:
await self.send_command(
HCI_LE_Set_Extended_Scan_Response_Data_Command(
advertising_handle=adv_handle,
operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA,
fragment_preference=0x01, # Should not fragment
scan_response_data=scan_response,
scan_response_data=scan_response_data,
),
check_result=True,
)
@@ -1774,8 +1856,16 @@ class Device(CompositeEventEmitter):
)
raise error
self.extended_advertising_handles.add(adv_handle)
return adv_handle
advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser(
device=self,
handle=adv_handle,
advertising_properties=advertising_properties,
own_address_type=own_address_type,
auto_restart=auto_restart,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
return advertiser
@experimental('Extended Advertising is still experimental - Might be changed soon.')
async def stop_extended_advertising(self, adv_handle: int) -> None:
@@ -1799,11 +1889,11 @@ class Device(CompositeEventEmitter):
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle),
check_result=True,
)
self.extended_advertising_handles.remove(adv_handle)
del self.extended_advertisers[adv_handle]
@property
def is_advertising(self):
return self.advertising
return self.legacy_advertiser or self.extended_advertisers
async def start_scanning(
self,
@@ -3144,13 +3234,18 @@ class Device(CompositeEventEmitter):
# Guess which own address type is used for this connection.
# This logic is somewhat correct but may need to be improved
# when multiple advertising are run simultaneously.
advertiser = None
if self.connect_own_address_type is not None:
own_address_type = self.connect_own_address_type
elif self.legacy_advertiser:
own_address_type = self.legacy_advertiser.own_address_type
# Store advertiser for restarting - it's only required for legacy, since
# extended advertisement produces HCI_Advertising_Set_Terminated.
if self.legacy_advertiser.auto_restart:
advertiser = self.legacy_advertiser
else:
own_address_type = self.advertising_own_address_type
# We are no longer advertising
self.advertising = False
# For extended advertisement, determining own address type later.
own_address_type = OwnAddressType.RANDOM
if own_address_type in (
OwnAddressType.PUBLIC,
@@ -3172,6 +3267,7 @@ class Device(CompositeEventEmitter):
connection_parameters,
ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY),
)
connection.advertiser_after_disconnection = advertiser
self.connections[connection_handle] = connection
# If supported, read which PHY we're connected with before
@@ -3203,10 +3299,10 @@ class Device(CompositeEventEmitter):
# For directed advertising, this means a timeout
if (
transport == BT_LE_TRANSPORT
and self.advertising
and self.advertising_type.is_directed
and self.legacy_advertiser
and self.legacy_advertiser.advertising_type.is_directed
):
self.advertising = False
self.legacy_advertiser = None
# Notify listeners
error = core.ConnectionError(
@@ -3268,16 +3364,30 @@ class Device(CompositeEventEmitter):
self.gatt_server.on_disconnection(connection)
# Restart advertising if auto-restart is enabled
if self.auto_restart_advertising:
if advertiser := connection.advertiser_after_disconnection:
logger.debug('restarting advertising')
self.abort_on(
'flush',
self.start_advertising(
advertising_type=self.advertising_type, # type: ignore[arg-type]
own_address_type=self.advertising_own_address_type, # type: ignore[arg-type]
auto_restart=True,
),
)
if isinstance(advertiser, LegacyAdvertiser):
self.abort_on(
'flush',
self.start_legacy_advertising(
advertising_type=advertiser.advertising_type,
own_address_type=advertiser.own_address_type,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
auto_restart=True,
),
)
elif isinstance(advertiser, ExtendedAdvertiser):
self.abort_on(
'flush',
self.start_extended_advertising(
advertising_properties=advertiser.advertising_properties,
own_address_type=advertiser.own_address_type,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
auto_restart=True,
),
)
elif sco_link := self.sco_links.pop(connection_handle, None):
sco_link.emit('disconnection', reason)
elif cis_link := self.cis_links.pop(connection_handle, None):
@@ -3600,6 +3710,30 @@ class Device(CompositeEventEmitter):
if sco_link := self.sco_links.get(sco_handle, None):
sco_link.emit('pdu', packet)
# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_advertising_set_termination(
self,
status: int,
advertising_handle: int,
connection_handle: int,
) -> None:
if status == HCI_SUCCESS:
connection = self.lookup_connection(connection_handle)
if advertiser := self.extended_advertisers.pop(advertising_handle, None):
if connection:
if advertiser.auto_restart:
connection.advertiser_after_disconnection = advertiser
if advertiser.own_address_type in (
OwnAddressType.PUBLIC,
OwnAddressType.RESOLVABLE_OR_PUBLIC,
):
connection.self_address = self.public_address
else:
connection.self_address = self.random_address
advertiser.emit('termination', status)
# [LE only]
@host_event_handler
@with_connection_from_handle
+28 -32
View File
@@ -19,12 +19,17 @@ like loading firmware after a cold start.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
from __future__ import annotations
import logging
import pathlib
import platform
from . import rtk
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
@@ -32,40 +37,31 @@ from . import rtk
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host):
"""Probe all known diver classes until one returns a valid instance for a host,
or none is found.
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
"""
if driver := await rtk.Driver.for_host(host):
logger.debug("Instantiated RTK driver")
return driver
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None
+45
View File
@@ -0,0 +1,45 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
+11 -4
View File
@@ -41,7 +41,7 @@ from bumble.hci import (
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common
# -----------------------------------------------------------------------------
# Logging
@@ -285,7 +285,7 @@ class Firmware:
)
class Driver:
class Driver(common.Driver):
@dataclass
class DriverInfo:
rom: int
@@ -470,8 +470,12 @@ class Driver:
logger.debug("USB metadata not found")
return False
vendor_id = host.hci_metadata.get("vendor_id", None)
product_id = host.hci_metadata.get("product_id", None)
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
@@ -486,6 +490,9 @@ class Driver:
@classmethod
async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
+1 -1
View File
@@ -961,7 +961,7 @@ class Server(EventEmitter):
try:
attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
logger.exception(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''
+41 -19
View File
@@ -728,6 +728,19 @@ HCI_LE_PHY_TYPE_TO_BIT = {
HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT
}
class Phy(enum.IntEnum):
LE_1M = 0x01
LE_2M = 0x02
LE_CODED = 0x03
class PhyBit(enum.IntFlag):
LE_1M = 0b00000001
LE_2M = 0b00000010
LE_CODED = 0b00000100
# Connection Parameters
HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25
HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25
@@ -1963,25 +1976,15 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_
# -----------------------------------------------------------------------------
class OwnAddressType:
class OwnAddressType(enum.IntEnum):
PUBLIC = 0
RANDOM = 1
RESOLVABLE_OR_PUBLIC = 2
RESOLVABLE_OR_RANDOM = 3
TYPE_NAMES = {
PUBLIC: 'PUBLIC',
RANDOM: 'RANDOM',
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM',
}
@staticmethod
def type_name(type_id):
return name_or_number(OwnAddressType.TYPE_NAMES, type_id)
# pylint: disable-next=unnecessary-lambda
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
@classmethod
def type_spec(cls):
return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
# -----------------------------------------------------------------------------
@@ -3374,7 +3377,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
),
},
),
('own_address_type', OwnAddressType.TYPE_SPEC),
('own_address_type', OwnAddressType.type_spec()),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('advertising_channel_map', 1),
@@ -3467,7 +3470,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
('le_scan_type', 1),
('le_scan_interval', 2),
('le_scan_window', 2),
('own_address_type', OwnAddressType.TYPE_SPEC),
('own_address_type', OwnAddressType.type_spec()),
('scanning_filter_policy', 1),
]
)
@@ -3506,7 +3509,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
('initiator_filter_policy', 1),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('own_address_type', OwnAddressType.TYPE_SPEC),
('own_address_type', OwnAddressType.type_spec()),
('connection_interval_min', 2),
('connection_interval_max', 2),
('max_latency', 2),
@@ -3913,7 +3916,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
),
},
),
('own_address_type', OwnAddressType.TYPE_SPEC),
('own_address_type', OwnAddressType.type_spec()),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('advertising_filter_policy', 1),
@@ -4309,7 +4312,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
('initiator_filter_policy:', self.initiator_filter_policy),
(
'own_address_type: ',
OwnAddressType.type_name(self.own_address_type),
OwnAddressType(self.own_address_type).name,
),
(
'peer_address_type: ',
@@ -4551,6 +4554,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command):
See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command
'''
class Direction(enum.IntEnum):
HOST_TO_CONTROLLER = 0x00
CONTROLLER_TO_HOST = 0x01
connection_handle: int
data_path_direction: int
data_path_id: int
@@ -5190,6 +5197,21 @@ HCI_LE_Meta_Event.subevent_classes[
] = HCI_LE_Extended_Advertising_Report_Event
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
('status', 1),
('advertising_handle', 1),
('connection_handle', 2),
('number_completed_extended_advertising_events', 1),
]
)
class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event):
'''
See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event
'''
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)])
class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
+131 -90
View File
@@ -21,7 +21,7 @@ import collections
import logging
import struct
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast
from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
@@ -91,16 +91,49 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
class AclPacketQueue:
max_packet_size: int
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
def __init__(
self,
max_packet_size: int,
max_in_flight: int,
send: Callable[[HCI_Packet], None],
) -> None:
self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight
self.in_flight = 0
self.send = send
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
# fmt: on
def enqueue(self, packet: HCI_AclDataPacket) -> None:
self.packets.appendleft(packet)
self.check_queue()
if self.packets:
logger.debug(
f'{self.in_flight} ACL packets in flight, '
f'{len(self.packets)} in queue'
)
def check_queue(self) -> None:
while self.packets and self.in_flight < self.max_in_flight:
packet = self.packets.pop()
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None:
if packet_count > self.in_flight:
logger.warning(
color(
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
)
packet_count = self.in_flight
self.in_flight -= packet_count
self.check_queue()
# -----------------------------------------------------------------------------
@@ -111,6 +144,13 @@ class Connection:
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = (
host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT
else host.acl_packet_queue
)
assert acl_packet_queue
self.acl_packet_queue = acl_packet_queue
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet)
@@ -123,8 +163,10 @@ class Connection:
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
@@ -137,18 +179,11 @@ class Host(AbortableEventEmitter):
) -> None:
super().__init__()
self.hci_metadata = None
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
@@ -162,10 +197,7 @@ class Host(AbortableEventEmitter):
# Connect to the source and sink if specified
if controller_source:
controller_source.set_packet_sink(self)
self.hci_metadata = getattr(
controller_source, 'metadata', self.hci_metadata
)
self.set_packet_source(controller_source)
if controller_sink:
self.set_packet_sink(controller_sink)
@@ -200,17 +232,21 @@ class Host(AbortableEventEmitter):
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
# Instantiate and init a driver for the host if needed.
# NOTE: we don't keep a reference to the driver here, because we don't
# currently have a need for the driver later on. But if the driver interface
# evolves, it may be required, then, to store a reference to the driver in
# an object property.
reset_needed = True
if driver_factory is not None:
if driver := await driver_factory(self):
await driver.init_controller()
reset_needed = False
# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True
@@ -253,46 +289,54 @@ class Host(AbortableEventEmitter):
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
self.hc_acl_data_packet_length = (
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
)
self.acl_packet_queue = AclPacketQueue(
max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet,
)
hc_le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue
else:
# Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue(
max_packet_size=hc_le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
@@ -313,29 +357,31 @@ class Host(AbortableEventEmitter):
)
)
self.reset_done = True
@property
def controller(self) -> TransportSink:
def controller(self) -> Optional[TransportSink]:
return self.hci_sink
@controller.setter
def controller(self, controller):
def controller(self, controller) -> None:
self.set_packet_sink(controller)
if controller:
controller.set_packet_sink(self)
def set_packet_sink(self, sink: TransportSink) -> None:
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
self.hci_sink = sink
def set_packet_source(self, source: TransportSource) -> None:
source.set_packet_sink(self)
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
def send_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(bytes(packet))
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -383,6 +429,17 @@ class Host(AbortableEventEmitter):
asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
packet_queue = connection.acl_packet_queue
if packet_queue is None:
logger.warning(
f'no ACL packet queue for connection 0x{connection_handle:04X}'
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
@@ -390,8 +447,7 @@ class Host(AbortableEventEmitter):
offset = 0
pb_flag = 0
while bytes_remaining:
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
acl_packet = HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
@@ -399,34 +455,12 @@ class Host(AbortableEventEmitter):
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
)
self.queue_acl_packet(acl_packet)
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet)
pb_flag = 1
offset += data_total_length
bytes_remaining -= data_total_length
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet)
self.acl_packets_in_flight += 1
def supports_command(self, command):
# Find the support flag position for this command
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
@@ -549,7 +583,7 @@ class Host(AbortableEventEmitter):
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
return None
return
return self.on_command_processed(event)
@@ -557,18 +591,17 @@ class Host(AbortableEventEmitter):
return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event):
total_packets = sum(event.num_completed_packets)
if total_packets <= self.acl_packets_in_flight:
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(
color(
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets
):
if not (connection := self.connections.get(connection_handle)):
logger.warning(
'received packet completion event for unknown handle '
f'0x{connection_handle:04X}'
)
)
self.acl_packets_in_flight = 0
continue
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
# Classic only
def on_hci_connection_request_event(self, event):
@@ -721,6 +754,14 @@ class Host(AbortableEventEmitter):
def on_hci_le_extended_advertising_report_event(self, event):
self.on_hci_le_advertising_report_event(event)
def on_hci_le_advertising_set_terminated_event(self, event):
self.emit(
'advertising_set_termination',
event.status,
event.advertising_handle,
event.connection_handle,
)
def on_hci_le_cis_request_event(self, event):
self.emit(
'cis_request',
+2 -2
View File
@@ -151,8 +151,8 @@ L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
+752 -1
View File
@@ -23,13 +23,21 @@ import dataclasses
import enum
import struct
import functools
from typing import Optional, List, Union
import logging
from typing import Optional, List, Union, Type, Dict, Any, Tuple, cast
from bumble import colors
from bumble import device
from bumble import hci
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -220,6 +228,231 @@ class SupportedFrameDuration(enum.IntFlag):
DURATION_10000_US_PREFERRED = 0b0010
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -325,6 +558,80 @@ class CodecSpecificCapabilities:
)
@dataclasses.dataclass
class CodecSpecificConfiguration:
'''See:
* Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures
* Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements
'''
class Type(enum.IntEnum):
# fmt: off
SAMPLING_FREQUENCY = 0x01
FRAME_DURATION = 0x02
AUDIO_CHANNEL_ALLOCATION = 0x03
OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05
sampling_frequency: SamplingFrequency
frame_duration: FrameDuration
audio_channel_allocation: AudioLocation
octets_per_codec_frame: int
codec_frames_per_sdu: int
@classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
offset = 0
# Allowed default values.
audio_channel_allocation = AudioLocation.NOT_ALLOWED
codec_frames_per_sdu = 1
while offset < len(data):
length, type = struct.unpack_from('BB', data, offset)
offset += 2
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY:
sampling_frequency = SamplingFrequency(value)
elif type == CodecSpecificConfiguration.Type.FRAME_DURATION:
frame_duration = FrameDuration(value)
elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION:
audio_channel_allocation = AudioLocation(value)
elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME:
octets_per_codec_frame = value
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
return CodecSpecificConfiguration(
sampling_frequency=sampling_frequency,
frame_duration=frame_duration,
audio_channel_allocation=audio_channel_allocation,
octets_per_codec_frame=octets_per_codec_frame,
codec_frames_per_sdu=codec_frames_per_sdu,
)
def __bytes__(self) -> bytes:
return struct.pack(
'<BBBBBBBBIBBHBBB',
2,
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
self.sampling_frequency,
2,
CodecSpecificConfiguration.Type.FRAME_DURATION,
self.frame_duration,
5,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
self.audio_channel_allocation,
3,
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
self.octets_per_codec_frame,
2,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
self.codec_frames_per_sdu,
)
@dataclasses.dataclass
class PacRecord:
coding_format: hci.CodingFormat
@@ -452,6 +759,429 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
super().__init__(characteristics)
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
# TODO: Parse this
metadata = b''
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if cis_id == self.cis_id and self.state == self.State.ENABLING:
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING:
self.state = self.State.STREAMING
self.cis_link = cis_link
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.DISABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: device.Connection) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = [],
sink_ase_id: Sequence[int] = [],
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def on_write_ase_control_point(self, connection, data):
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
@@ -494,3 +1224,24 @@ class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]
+24 -17
View File
@@ -118,8 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DEFAULT_WINDOW_SIZE = 16
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -438,14 +438,16 @@ class DLC(EventEmitter):
multiplexer: Multiplexer,
dlci: int,
max_frame_size: int,
initial_tx_credits: int,
window_size: int,
) -> None:
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rx_credits = window_size
self.rx_threshold = window_size // 2
self.tx_credits = window_size
self.tx_buffer = b''
self.state = DLC.State.INIT
self.role = multiplexer.role
@@ -537,11 +539,11 @@ class DLC(EventEmitter):
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -580,9 +582,9 @@ class DLC(EventEmitter):
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=self.max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
window_size=self.window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
@@ -591,7 +593,7 @@ class DLC(EventEmitter):
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return self.window_size - self.rx_credits
return 0
@@ -843,7 +845,12 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(self, channel: int) -> DLC:
async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
) -> DLC:
if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress')
@@ -855,9 +862,9 @@ class Multiplexer(EventEmitter):
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
window_size=window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
+49 -21
View File
@@ -18,6 +18,7 @@
from contextlib import asynccontextmanager
import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..snoop import create_snooper
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport:
"""
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The name must be <type>:<metadata><parameters>
Where <parameters> depend on the type (and may be empty for some types), and
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are:
* serial
* udp
@@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport:
* android-netsim
"""
return _wrap_transport(await _open_transport(name))
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
else:
metadata = None
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
return await open_serial_transport(spec)
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
return await open_udp_transport(spec)
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
return await open_tcp_client_transport(spec)
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
return await open_tcp_server_transport(spec)
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
return await open_ws_client_transport(spec)
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
return await open_ws_server_transport(spec)
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
return await open_pty_transport(spec)
if scheme == 'file':
from .file import open_file_transport
assert spec is not None
return await open_file_transport(spec[0])
return await open_file_transport(spec)
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
return await open_vhci_transport(spec)
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
return await open_hci_socket_transport(spec)
if scheme == 'usb':
from .usb import open_usb_transport
assert spec is not None
return await open_usb_transport(spec[0])
assert spec
return await open_usb_transport(spec)
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
assert spec is not None
return await open_pyusb_transport(spec[0])
assert spec
return await open_pyusb_transport(spec)
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
return await open_android_emulator_transport(spec)
if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec[0] if spec else None)
return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme')
+1 -1
View File
@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
mode = 'host'
server_host = 'localhost'
server_port = '8554'
if spec is not None:
if spec:
params = spec.split(',')
for param in params:
if param.startswith('mode='):
+1 -1
View File
@@ -21,7 +21,7 @@ import struct
import asyncio
import logging
import io
from typing import ContextManager, Tuple, Optional, Protocol, Dict
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import hci
from bumble.colors import color
+1 -4
View File
@@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
) from error
# Compute the adapter index
if spec is None:
adapter_index = 0
else:
adapter_index = int(spec)
adapter_index = int(spec) if spec else 0
# Bind the socket
# NOTE: since Python doesn't support binding with the required address format (yet),
+1 -1
View File
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
READ_SIZE = 4096
class UsbPacketSink:
def __init__(self, device, acl_out):
+9
View File
@@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly.
This may include, for instance, loading a Firmware image or patch,
loading a configuration.
By default, drivers will be automatically probed to determine if they should be
used with particular HCI controller.
When the transport for an HCI controller is instantiated from a transport name,
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
metadata portion of the transport name. For example,
``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the
first USB device, even if a normal probe would not have selected it based on the
USB vendor ID and product ID.
Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
+5 -2
View File
@@ -1,13 +1,16 @@
REALTEK DRIVER
==============
This driver supports loading firmware images and optional config data to
This driver supports loading firmware images and optional config data to
USB dongles with a Realtek chipset.
A number of USB dongles are supported, but likely not all.
When using a USB dongle, the USB product ID and manufacturer ID are used
When using a USB dongle, the USB product ID and vendor ID are used
to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for those files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
+1
View File
@@ -1,5 +1,6 @@
{
"name": "Bumble-LEA",
"keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"advertising_interval": 100
}
+51 -1
View File
@@ -19,12 +19,14 @@ import asyncio
import logging
import sys
import os
import struct
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.device import Device, CisLink
from bumble.hci import (
CodecID,
CodingFormat,
OwnAddressType,
HCI_IsoDataPacket,
HCI_LE_Set_Extended_Advertising_Parameters_Command,
)
from bumble.profiles.bap import (
@@ -35,6 +37,7 @@ from bumble.profiles.bap import (
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
)
from bumble.transport import open_transport_or_link
@@ -103,6 +106,8 @@ async def main() -> None:
)
)
device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2]))
advertising_data = bytes(
AdvertisingData(
[
@@ -110,6 +115,16 @@ async def main() -> None:
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
@@ -117,6 +132,41 @@ async def main() -> None:
]
)
)
subprocess = await asyncio.create_subprocess_shell(
f'dlc3 | ffplay pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdin = subprocess.stdin
assert stdin
# Write a fake LC3 header to dlc3.
stdin.write(
bytes([0x1C, 0xCC]) # Header.
+ struct.pack(
'<HHHHHHI',
18, # Header length.
24000 // 100, # Sampling Rate(/100Hz).
0, # Bitrate(unused).
1, # Channels.
10000 // 10, # Frame duration(/10us).
0, # RFU.
0x0FFFFFFF, # Frame counts.
)
)
def on_pdu(pdu: HCI_IsoDataPacket):
# LC3 format: |frame_length(2)| + |frame(length)|.
if pdu.iso_sdu_length:
stdin.write(struct.pack('<H', pdu.iso_sdu_length))
stdin.write(pdu.iso_sdu_fragment)
def on_cis(cis_link: CisLink):
cis_link.on('pdu', on_pdu)
device.once('cis_establishment', on_cis)
await device.start_extended_advertising(
advertising_properties=(
@@ -42,6 +42,7 @@ public class HciServer {
try (ServerSocket serverSocket = new ServerSocket(mPort)) {
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
try (Socket clientSocket = serverSocket.accept()) {
clientSocket.setTcpNoDelay(true);
mListener.onHostConnectionState(true);
mListener.onMessage("Connected");
HciParser parser = new HciParser(mListener);
+251
View File
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import asyncio
import os
import functools
import pytest
import logging
@@ -24,11 +25,26 @@ from bumble import device
from bumble.hci import CodecID, CodingFormat
from bumble.profiles.bap import (
AudioLocation,
AseStateMachine,
ASE_Operation,
ASE_Config_Codec,
ASE_Config_QOS,
ASE_Disable,
ASE_Enable,
ASE_Receiver_Start_Ready,
ASE_Receiver_Stop_Ready,
ASE_Release,
ASE_Update_Metadata,
SupportedFrameDuration,
SupportedSamplingFrequency,
SamplingFrequency,
FrameDuration,
CodecSpecificCapabilities,
CodecSpecificConfiguration,
ContextType,
PacRecord,
AudioStreamControlService,
AudioStreamControlServiceProxy,
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)
@@ -40,6 +56,13 @@ from .test_utils import TwoDevices
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def basic_check(operation: ASE_Operation):
serialized = bytes(operation)
parsed = ASE_Operation.from_bytes(serialized)
assert bytes(parsed) == serialized
# -----------------------------------------------------------------------------
def test_codec_specific_capabilities() -> None:
SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000
@@ -85,6 +108,92 @@ def test_vendor_specific_pac_record() -> None:
assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA
# -----------------------------------------------------------------------------
def test_ASE_Config_Codec() -> None:
operation = ASE_Config_Codec(
ase_id=[1, 2],
target_latency=[3, 4],
target_phy=[5, 6],
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
codec_specific_configuration=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Config_QOS() -> None:
operation = ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 2],
cis_id=[3, 4],
sdu_interval=[5, 6],
framing=[0, 1],
phy=[2, 3],
max_sdu=[4, 5],
retransmission_number=[6, 7],
max_transport_latency=[8, 9],
presentation_delay=[10, 11],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Enable() -> None:
operation = ASE_Enable(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Update_Metadata() -> None:
operation = ASE_Update_Metadata(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Disable() -> None:
operation = ASE_Disable(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Release() -> None:
operation = ASE_Release(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Receiver_Start_Ready() -> None:
operation = ASE_Receiver_Start_Ready(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Receiver_Stop_Ready() -> None:
operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_codec_specific_configuration() -> None:
SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000
FRAME_SURATION = FrameDuration.DURATION_10000_US
AUDIO_LOCATION = AudioLocation.FRONT_LEFT
config = CodecSpecificConfiguration(
sampling_frequency=SAMPLE_FREQUENCY,
frame_duration=FRAME_SURATION,
audio_channel_allocation=AUDIO_LOCATION,
octets_per_codec_frame=60,
codec_frames_per_sdu=1,
)
assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_pacs():
@@ -140,6 +249,148 @@ async def test_pacs():
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ascs():
devices = TwoDevices()
devices[0].add_service(
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
ascs_client = await peer.discover_service_and_create_proxy(
AudioStreamControlServiceProxy
)
notifications = {1: asyncio.Queue(), 2: asyncio.Queue()}
def on_notification(data: bytes, ase_id: int):
notifications[ase_id].put_nowait(data)
# Should be idle
assert await ascs_client.sink_ase[0].read_value() == bytes(
[1, AseStateMachine.State.IDLE]
)
assert await ascs_client.sink_ase[1].read_value() == bytes(
[2, AseStateMachine.State.IDLE]
)
# Subscribe
await ascs_client.sink_ase[0].subscribe(
functools.partial(on_notification, ase_id=1)
)
await ascs_client.sink_ase[1].subscribe(
functools.partial(on_notification, ase_id=2)
)
# Config Codec
config = CodecSpecificConfiguration(
sampling_frequency=SamplingFrequency.FREQ_48000,
frame_duration=FrameDuration.DURATION_10000_US,
audio_channel_allocation=AudioLocation.FRONT_LEFT,
octets_per_codec_frame=120,
codec_frames_per_sdu=1,
)
await ascs_client.ase_control_point.write_value(
ASE_Config_Codec(
ase_id=[1, 2],
target_latency=[3, 4],
target_phy=[5, 6],
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
codec_specific_configuration=[config, config],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.CODEC_CONFIGURED]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.CODEC_CONFIGURED]
)
# Config QOS
await ascs_client.ase_control_point.write_value(
ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 2],
cis_id=[3, 4],
sdu_interval=[5, 6],
framing=[0, 1],
phy=[2, 3],
max_sdu=[4, 5],
retransmission_number=[6, 7],
max_transport_latency=[8, 9],
presentation_delay=[10, 11],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.QOS_CONFIGURED]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.QOS_CONFIGURED]
)
# Enable
await ascs_client.ase_control_point.write_value(
ASE_Enable(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.ENABLING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.ENABLING]
)
# CIS establishment
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=5,
cis_id=3,
cig_id=1,
),
)
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=6,
cis_id=4,
cig_id=2,
),
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.STREAMING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.STREAMING]
)
# Release
await ascs_client.ase_control_point.write_value(
ASE_Release(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.RELEASING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.RELEASING]
)
assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE])
assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE])
await asyncio.sleep(0.001)
# -----------------------------------------------------------------------------
async def run():
await test_pacs()
+182 -2
View File
@@ -20,16 +20,23 @@ import logging
import os
from types import LambdaType
import pytest
from unittest import mock
from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
ConnectionParameters,
)
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.host import AclPacketQueue, Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING,
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
OwnAddressType,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
@@ -66,6 +73,13 @@ async def test_device_connect_parallel():
d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None))
def _send(packet):
pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
# enable classic
d0.classic_enabled = True
d1.classic_enabled = True
@@ -232,6 +246,172 @@ async def test_flush():
pass
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
device = Device(host=mock.AsyncMock(Host))
# Start advertising
advertiser = await device.start_legacy_advertising()
assert device.legacy_advertiser
# Stop advertising
await advertiser.stop()
assert not device.legacy_advertiser
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
# Start advertising
advertiser = await device.start_legacy_advertising()
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
# For unknown reason, read_phy() in on_connection() would be killed at the end of
# test, so we force scheduling here to avoid an warning.
await asyncio.sleep(0.0001)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
(True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_legacy_advertising(auto_restart=auto_restart)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.start_legacy_advertising = mock.AsyncMock()
device.on_disconnection(0x0001, 0)
if auto_restart:
device.start_legacy_advertising.assert_called_with(
advertising_type=advertiser.advertising_type,
own_address_type=advertiser.own_address_type,
auto_restart=advertiser.auto_restart,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
)
else:
device.start_legacy_advertising.assert_not_called()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
device = Device(host=mock.AsyncMock(Host))
# Start advertising
advertiser = await device.start_extended_advertising()
assert device.extended_advertisers
# Stop advertising
await advertiser.stop()
assert not device.extended_advertisers
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_extended_advertising(
own_address_type=own_address_type
)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertiser.handle,
0x0001,
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
# For unknown reason, read_phy() in on_connection() would be killed at the end of
# test, so we force scheduling here to avoid an warning.
await asyncio.sleep(0.0001)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
(True, False),
)
@pytest.mark.asyncio
async def test_extended_advertising_disconnection(auto_restart):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_extended_advertising(auto_restart=auto_restart)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertiser.handle,
0x0001,
)
device.start_extended_advertising = mock.AsyncMock()
device.on_disconnection(0x0001, 0)
if auto_restart:
device.start_extended_advertising.assert_called_with(
advertising_properties=advertiser.advertising_properties,
own_address_type=advertiser.own_address_type,
auto_restart=advertiser.auto_restart,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
)
else:
device.start_extended_advertising.assert_not_called()
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))