Compare commits

...

17 Commits

Author SHA1 Message Date
khsiao-google
ca23d6b89a Revert "Improve connection related functions and names" 2025-09-10 15:00:41 +08:00
khsiao-google
d86d69d816 Merge pull request #771 from khsiao-google/update
Improve connection related functions and names
2025-09-10 14:56:38 +08:00
khsiao-google
dc93f32a9a Replace core.ConnectionParameters by Connection.Parameters in device.py 2025-09-08 02:00:49 +00:00
zxzxwu
9838908a26 Merge pull request #772 from zxzxwu/hap
HAP: Slightly Pythonic refactor
2025-09-05 23:08:09 +08:00
Josh Wu
613519f0b3 HAP: Slightly Pythonic refactor
* Add missing type annotations
* Avoid __value__ and _ arguments (this will be a problem for override).
* Replace while-pop with for loop
2025-09-05 21:02:16 +08:00
zxzxwu
a943ea57ef Merge pull request #770 from zxzxwu/avrcp
AVRCP: Implement most commands and responses
2025-09-04 16:18:54 +08:00
Josh Wu
14401910bb AVRCP: Implement most commands and responses 2025-09-03 13:20:10 +08:00
khsiao-google
5d35ed471c Merge pull request #769 from khsiao-google/update
Add typing for host.py
2025-09-02 14:59:27 +08:00
khsiao-google
c720ad5fdc Add typing for host.py 2025-09-02 06:01:39 +00:00
khsiao-google
f02183f95d Merge pull request #764 from khsiao-google/update
Add typing for device.py
2025-09-01 15:19:57 +08:00
khsiao-google
d903937a51 Merge branch 'main' into update 2025-09-01 07:14:19 +00:00
zxzxwu
6381ee0ab1 Merge pull request #767 from zxzxwu/avrcp
Migrate AVRCP packets to dataclasses
2025-09-01 13:26:56 +08:00
Gilles Boccon-Gibod
59d99780e1 Merge pull request #768 from google/gbg/data-types
add support for data type classes
2025-08-30 13:04:32 -07:00
Josh Wu
9f3d8c9b49 Migrate AVRCP responses to dataclasses 2025-08-28 21:42:38 +08:00
Josh Wu
31961febe5 Migrate AVRCP events to dataclasses 2025-08-28 17:00:20 +08:00
Josh Wu
dab0993cba Migrate AVRCP packets to dataclasses 2025-08-28 17:00:20 +08:00
khsiao-google
3333ba472b Add typing for device.py 2025-08-26 09:22:06 +00:00
6 changed files with 1762 additions and 850 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -48,6 +48,7 @@ from typing_extensions import Self
from bumble import (
core,
data_types,
gatt,
gatt_client,
gatt_server,
hci,
@@ -264,7 +265,7 @@ class ExtendedAdvertisement(Advertisement):
# -----------------------------------------------------------------------------
class AdvertisementDataAccumulator:
def __init__(self, passive=False):
def __init__(self, passive: bool = False):
self.passive = passive
self.last_advertisement = None
self.last_data = b''
@@ -1243,7 +1244,7 @@ class LePhyOptions:
PREFER_S_2_CODED_PHY = 1
PREFER_S_8_CODED_PHY = 2
def __init__(self, coded_phy_preference=0):
def __init__(self, coded_phy_preference: int = 0):
self.coded_phy_preference = coded_phy_preference
def __int__(self):
@@ -1691,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter):
self_address: hci.Address
self_resolvable_address: Optional[hci.Address]
peer_address: hci.Address
peer_name: Optional[str]
peer_resolvable_address: Optional[hci.Address]
peer_le_features: Optional[hci.LeFeatureMask]
role: hci.Role
@@ -1931,7 +1933,7 @@ class Connection(utils.CompositeEventEmitter):
self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result)
self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None:
async def set_data_length(self, tx_octets: int, tx_time: int) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)
async def update_parameters(
@@ -1961,7 +1963,12 @@ class Connection(utils.CompositeEventEmitter):
use_l2cap=use_l2cap,
)
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
async def set_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options)
async def get_phy(self) -> ConnectionPHY:
@@ -2156,7 +2163,7 @@ class DeviceConfiguration:
# Decorator that converts the first argument from a connection handle to a connection
def with_connection_from_handle(function):
@functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs):
def wrapper(self, connection_handle: int, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None:
raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
@@ -2169,7 +2176,7 @@ def with_connection_from_handle(function):
# Decorator that converts the first argument from a bluetooth address to a connection
def with_connection_from_address(function):
@functools.wraps(function)
def wrapper(self, address, *args, **kwargs):
def wrapper(self, address: hci.Address, *args, **kwargs):
if connection := self.pending_connections.get(address, False):
return function(self, connection, *args, **kwargs)
for connection in self.connections.values():
@@ -2655,7 +2662,7 @@ class Device(utils.CompositeEventEmitter):
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
async def send_command(self, command, check_result=False):
async def send_command(self, command: hci.HCI_Command, check_result: bool = False):
try:
return await asyncio.wait_for(
self.host.send_command(command, check_result), self.command_timeout
@@ -2912,13 +2919,13 @@ class Device(utils.CompositeEventEmitter):
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
return self.host.supports_le_features(feature)
def supports_le_phy(self, phy: int) -> bool:
if phy == hci.HCI_LE_1M_PHY:
def supports_le_phy(self, phy: hci.Phy) -> bool:
if phy == hci.Phy.LE_1M:
return True
feature_map: dict[int, hci.LeFeatureMask] = {
hci.HCI_LE_2M_PHY: hci.LeFeatureMask.LE_2M_PHY,
hci.HCI_LE_CODED_PHY: hci.LeFeatureMask.LE_CODED_PHY,
feature_map: dict[hci.Phy, hci.LeFeatureMask] = {
hci.Phy.LE_2M: hci.LeFeatureMask.LE_2M_PHY,
hci.Phy.LE_CODED: hci.LeFeatureMask.LE_CODED_PHY,
}
if phy not in feature_map:
raise InvalidArgumentError('invalid PHY')
@@ -3522,7 +3529,9 @@ class Device(utils.CompositeEventEmitter):
self.discovering = False
@host_event_handler
def on_inquiry_result(self, address, class_of_device, data, rssi):
def on_inquiry_result(
self, address: hci.Address, class_of_device: int, data: bytes, rssi: int
):
self.emit(
self.EVENT_INQUIRY_RESULT,
address,
@@ -3531,7 +3540,9 @@ class Device(utils.CompositeEventEmitter):
rssi,
)
async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled):
async def set_scan_enable(
self, inquiry_scan_enabled: bool, page_scan_enabled: bool
):
if inquiry_scan_enabled and page_scan_enabled:
scan_enable = 0x03
elif page_scan_enabled:
@@ -3657,6 +3668,7 @@ class Device(utils.CompositeEventEmitter):
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, transport
) # TODO: timeout
@@ -3684,7 +3696,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if transport == PhysicalTransport.LE or (
# match BR/EDR connection failure event against peer address
error.transport == transport
@@ -3904,6 +3916,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -3962,7 +3975,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if (
error.transport == PhysicalTransport.BR_EDR
and error.peer_address == peer_address
@@ -3999,7 +4012,7 @@ class Device(utils.CompositeEventEmitter):
self.pending_connections.pop(peer_address, None)
@asynccontextmanager
async def connect_as_gatt(self, peer_address):
async def connect_as_gatt(self, peer_address: Union[hci.Address, str]):
async with AsyncExitStack() as stack:
connection = await stack.enter_async_context(
await self.connect(peer_address)
@@ -4035,6 +4048,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -4080,7 +4094,9 @@ class Device(utils.CompositeEventEmitter):
)
self.disconnecting = False
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
async def set_data_length(
self, connection: Connection, tx_octets: int, tx_time: int
) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB:
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
@@ -4183,7 +4199,11 @@ class Device(utils.CompositeEventEmitter):
)
async def set_connection_phy(
self, connection, tx_phys=None, rx_phys=None, phy_options=None
self,
connection: Connection,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
if not self.host.supports_command(hci.HCI_LE_SET_PHY_COMMAND):
logger.warning('ignoring request, command not supported')
@@ -4199,7 +4219,7 @@ class Device(utils.CompositeEventEmitter):
all_phys=all_phys_bits,
tx_phys=hci.phy_list_to_bits(tx_phys),
rx_phys=hci.phy_list_to_bits(rx_phys),
phy_options=0 if phy_options is None else int(phy_options),
phy_options=phy_options,
)
)
@@ -4210,7 +4230,11 @@ class Device(utils.CompositeEventEmitter):
)
raise hci.HCI_StatusError(result)
async def set_default_phy(self, tx_phys=None, rx_phys=None):
async def set_default_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
):
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
)
@@ -4248,7 +4272,7 @@ class Device(utils.CompositeEventEmitter):
check_result=True,
)
async def find_peer_by_name(self, name, transport=PhysicalTransport.LE):
async def find_peer_by_name(self, name: str, transport=PhysicalTransport.LE):
"""
Scan for a peer with a given name and return its address.
"""
@@ -4263,7 +4287,7 @@ class Device(utils.CompositeEventEmitter):
if local_name == name:
peer_address.set_result(address)
listener = None
listener: Optional[Callable[..., None]] = None
was_scanning = self.scanning
was_discovering = self.discovering
try:
@@ -4369,10 +4393,10 @@ class Device(utils.CompositeEventEmitter):
def smp_session_proxy(self, session_proxy: type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy
async def pair(self, connection):
async def pair(self, connection: Connection):
return await self.smp_manager.pair(connection)
def request_pairing(self, connection):
def request_pairing(self, connection: Connection):
return self.smp_manager.request_pairing(connection)
async def get_long_term_key(
@@ -4460,7 +4484,7 @@ class Device(utils.CompositeEventEmitter):
on_authentication_failure,
)
async def encrypt(self, connection, enable=True):
async def encrypt(self, connection: Connection, enable: bool = True):
if not enable and connection.transport == PhysicalTransport.LE:
raise InvalidArgumentError('`enable` parameter is classic only.')
@@ -4470,7 +4494,7 @@ class Device(utils.CompositeEventEmitter):
def on_encryption_change():
pending_encryption.set_result(None)
def on_encryption_failure(error_code):
def on_encryption_failure(error_code: int):
pending_encryption.set_exception(hci.HCI_Error(error_code))
connection.on(
@@ -4562,10 +4586,10 @@ class Device(utils.CompositeEventEmitter):
async def switch_role(self, connection: Connection, role: hci.Role):
pending_role_change = asyncio.get_running_loop().create_future()
def on_role_change(new_role):
def on_role_change(new_role: hci.Role):
pending_role_change.set_result(new_role)
def on_role_change_failure(error_code):
def on_role_change_failure(error_code: int):
pending_role_change.set_exception(hci.HCI_Error(error_code))
connection.on(connection.EVENT_ROLE_CHANGE, on_role_change)
@@ -5155,10 +5179,10 @@ class Device(utils.CompositeEventEmitter):
):
connection.emit(connection.EVENT_LINK_KEY)
def add_service(self, service):
def add_service(self, service: gatt.Service):
self.gatt_server.add_service(service)
def add_services(self, services):
def add_services(self, services: Iterable[gatt.Service]):
self.gatt_server.add_services(services)
def add_default_services(
@@ -5254,10 +5278,10 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
def on_advertising_set_termination(
self,
status,
advertising_handle,
connection_handle,
number_of_completed_extended_advertising_events,
status: int,
advertising_handle: int,
connection_handle: int,
number_of_completed_extended_advertising_events: int,
):
# Legacy advertising set is also one of extended advertising sets.
if not (
@@ -5556,7 +5580,12 @@ class Device(utils.CompositeEventEmitter):
)
@host_event_handler
def on_connection_failure(self, transport, peer_address, error_code):
def on_connection_failure(
self,
transport: hci.PhysicalTransport,
peer_address: hci.Address,
error_code: int,
):
logger.debug(
f'*** Connection failed: {hci.HCI_Constant.error_name(error_code)}'
)
@@ -5675,7 +5704,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication(self, connection):
def on_connection_authentication(self, connection: Connection):
logger.debug(
f'*** Connection Authentication: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -5685,7 +5714,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication_failure(self, connection, error):
def on_connection_authentication_failure(
self, connection: Connection, error: core.ConnectionError
):
logger.debug(
f'*** Connection Authentication Failure: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, error={error}'
@@ -5727,10 +5758,13 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_address
def on_authentication_io_capability_response(
self, connection, io_capability, authentication_requirements
self,
connection: Connection,
io_capability: int,
authentication_requirements: int,
):
connection.peer_pairing_io_capability = io_capability
connection.peer_pairing_authentication_requirements = (
connection.pairing_peer_io_capability = io_capability
connection.pairing_peer_authentication_requirements = (
authentication_requirements
)
@@ -5741,7 +5775,7 @@ class Device(utils.CompositeEventEmitter):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
peer_io_capability = connection.peer_pairing_io_capability
peer_io_capability = connection.pairing_peer_io_capability
async def confirm() -> bool:
# Ask the user to confirm the pairing, without display
@@ -5816,7 +5850,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_request(self, connection) -> None:
def on_authentication_user_passkey_request(self, connection: Connection) -> None:
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5859,7 +5893,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_pin_code_request(self, connection):
def on_pin_code_request(self, connection: Connection):
# Classic legacy pairing
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5903,7 +5937,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_notification(self, connection, passkey):
def on_authentication_user_passkey_notification(
self, connection: Connection, passkey: int
):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5915,14 +5951,15 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name(self, connection: Connection, address, remote_name):
def on_remote_name(
self, connection: Connection, address: hci.Address, remote_name: bytes
):
# Try to decode the name
try:
remote_name = remote_name.decode('utf-8')
if connection:
connection.peer_name = remote_name
connection.peer_name = remote_name.decode('utf-8')
connection.emit(connection.EVENT_REMOTE_NAME)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name.decode('utf-8'))
except UnicodeDecodeError as error:
logger.warning('peer name is not valid UTF-8')
if connection:
@@ -5933,7 +5970,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name_failure(self, connection: Connection, address, error):
def on_remote_name_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
@@ -6134,7 +6173,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_encryption_key_refresh(self, connection):
def on_connection_encryption_key_refresh(self, connection: Connection):
logger.debug(
f'*** Connection Key Refresh: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -6172,7 +6211,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_parameters_update_failure(self, connection, error):
def on_connection_parameters_update_failure(
self, connection: Connection, error: int
):
logger.debug(
f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6182,7 +6223,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update(self, connection, phy):
def on_connection_phy_update(self, connection: Connection, phy: core.ConnectionPHY):
logger.debug(
f'*** Connection PHY Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6192,7 +6233,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update_failure(self, connection, error):
def on_connection_phy_update_failure(self, connection: Connection, error: int):
logger.debug(
f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6221,7 +6262,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection, att_mtu):
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6233,7 +6274,12 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_data_length_change(
self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time
self,
connection: Connection,
max_tx_octets: int,
max_tx_time: int,
max_rx_octets: int,
max_rx_time: int,
):
logger.debug(
f'*** Connection Data Length Change: [0x{connection.handle:04X}] '
@@ -6358,14 +6404,16 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_role_change(self, connection, new_role):
def on_role_change(self, connection: Connection, new_role: hci.Role):
connection.role = new_role
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_role_change_failure(self, connection, address, error):
def on_role_change_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error)
self.emit(self.EVENT_ROLE_CHANGE_FAILURE, address, error)
@@ -6379,7 +6427,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status) -> None:
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
def on_pairing_start(self, connection: Connection) -> None:
@@ -6403,7 +6451,7 @@ class Device(utils.CompositeEventEmitter):
connection.emit(connection.EVENT_PAIRING_FAILURE, reason)
@with_connection_from_handle
def on_gatt_pdu(self, connection, pdu):
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu)
@@ -6425,7 +6473,7 @@ class Device(utils.CompositeEventEmitter):
connection.gatt_server.on_gatt_pdu(connection, att_pdu)
@with_connection_from_handle
def on_smp_pdu(self, connection, pdu):
def on_smp_pdu(self, connection: Connection, pdu: bytes):
self.smp_manager.on_smp_pdu(connection, pdu)
@host_event_handler

View File

@@ -26,7 +26,17 @@ import secrets
import struct
from collections.abc import Sequence
from dataclasses import field
from typing import Any, Callable, ClassVar, Iterable, Optional, TypeVar, Union, cast
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import Self
@@ -111,23 +121,57 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
class SpecableEnum(utils.OpenIntEnum):
@classmethod
def type_spec(cls, size: int):
return {'size': size, 'mapper': lambda x: cls(x).name}
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
'serializer': lambda x: x.to_bytes(size, byteorder),
'parser': lambda data, offset: (
offset + size,
cls(int.from_bytes(data[offset : offset + size], byteorder)),
),
'mapper': lambda x: cls(x).name,
}
@classmethod
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
def type_metadata(
cls,
size: int,
list_begin: bool = False,
list_end: bool = False,
byteorder: Literal['little', 'big'] = 'little',
):
return metadata(
cls.type_spec(size, byteorder),
list_begin=list_begin,
list_end=list_end,
)
class SpecableFlag(enum.IntFlag):
@classmethod
def type_spec(cls, size: int):
return {'size': size, 'mapper': lambda x: cls(x).name}
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
'serializer': lambda x: x.to_bytes(size, byteorder),
'parser': lambda data, offset: (
offset + size,
cls(int.from_bytes(data[offset : offset + size], byteorder)),
),
'mapper': lambda x: cls(x).name,
}
@classmethod
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
def type_metadata(
cls,
size: int,
list_begin: bool = False,
list_end: bool = False,
byteorder: Literal['little', 'big'] = 'little',
):
return metadata(
cls.type_spec(size, byteorder),
list_begin=list_begin,
list_end=list_end,
)
# -----------------------------------------------------------------------------
@@ -6422,7 +6466,9 @@ class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event):
irc: int = field(metadata=metadata(1))
max_pdu: int = field(metadata=metadata(2))
iso_interval: int = field(metadata=metadata(2))
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
connection_handle: Sequence[int] = field(
metadata=metadata(2, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------
@@ -6454,7 +6500,9 @@ class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event):
irc: int = field(metadata=metadata(1))
max_pdu: int = field(metadata=metadata(2))
iso_interval: int = field(metadata=metadata(2))
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
connection_handle: Sequence[int] = field(
metadata=metadata(2, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------

View File

@@ -22,11 +22,16 @@ import collections
import dataclasses
import logging
import struct
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast
from bumble import drivers, hci, utils
from bumble.colors import color
from bumble.core import ConnectionParameters, ConnectionPHY, PhysicalTransport
from bumble.core import (
ConnectionParameters,
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError
@@ -902,10 +907,14 @@ class Host(utils.EventEmitter):
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(self, event):
def on_command_processed(
self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event]
):
if self.pending_response:
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
if self.pending_command is None:
logger.warning('!!! pending_command is None ')
elif self.pending_command.op_code != event.command_opcode:
logger.warning(
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
@@ -919,10 +928,10 @@ class Host(utils.EventEmitter):
############################################################
# HCI handlers
############################################################
def on_hci_event(self, event):
def on_hci_event(self, event: hci.HCI_Event):
logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')
def on_hci_command_complete_event(self, event):
def on_hci_command_complete_event(self, event: hci.HCI_Command_Complete_Event):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
@@ -931,7 +940,7 @@ class Host(utils.EventEmitter):
return self.on_command_processed(event)
def on_hci_command_status_event(self, event):
def on_hci_command_status_event(self, event: hci.HCI_Command_Status_Event):
return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(
@@ -951,7 +960,7 @@ class Host(utils.EventEmitter):
)
# Classic only
def on_hci_connection_request_event(self, event):
def on_hci_connection_request_event(self, event: hci.HCI_Connection_Request_Event):
# Notify the listeners
self.emit(
'connection_request',
@@ -960,7 +969,14 @@ class Host(utils.EventEmitter):
event.link_type,
)
def on_hci_le_connection_complete_event(self, event):
def on_hci_le_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Check if this is a cancellation
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
@@ -1006,15 +1022,25 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_le_enhanced_connection_complete_event(self, event):
def on_hci_le_enhanced_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Just use the same implementation as for the non-enhanced event for now
self.on_hci_le_connection_complete_event(event)
def on_hci_le_enhanced_connection_complete_v2_event(self, event):
def on_hci_le_enhanced_connection_complete_v2_event(
self, event: hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
):
# Just use the same implementation as for the v1 event for now
self.on_hci_le_enhanced_connection_complete_event(event)
def on_hci_connection_complete_event(self, event):
def on_hci_connection_complete_event(
self, event: hci.HCI_Connection_Complete_Event
):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
@@ -1054,7 +1080,9 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_disconnection_complete_event(self, event):
def on_hci_disconnection_complete_event(
self, event: hci.HCI_Disconnection_Complete_Event
):
# Find the connection
handle = event.connection_handle
if (
@@ -1093,7 +1121,9 @@ class Host(utils.EventEmitter):
# Notify the listeners
self.emit('disconnection_failure', handle, event.status)
def on_hci_le_connection_update_complete_event(self, event):
def on_hci_le_connection_update_complete_event(
self, event: hci.HCI_LE_Connection_Update_Complete_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
return
@@ -1113,7 +1143,9 @@ class Host(utils.EventEmitter):
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event):
def on_hci_le_phy_update_complete_event(
self, event: hci.HCI_LE_PHY_Update_Complete_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
return
@@ -1143,7 +1175,9 @@ class Host(utils.EventEmitter):
):
self.on_hci_le_advertising_report_event(event)
def on_hci_le_advertising_set_terminated_event(self, event):
def on_hci_le_advertising_set_terminated_event(
self, event: hci.HCI_LE_Advertising_Set_Terminated_Event
):
self.emit(
'advertising_set_termination',
event.status,
@@ -1152,7 +1186,9 @@ class Host(utils.EventEmitter):
event.num_completed_extended_advertising_events,
)
def on_hci_le_periodic_advertising_sync_established_event(self, event):
def on_hci_le_periodic_advertising_sync_established_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_Event
):
self.emit(
'periodic_advertising_sync_establishment',
event.status,
@@ -1164,16 +1200,22 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_lost_event(self, event):
def on_hci_le_periodic_advertising_sync_lost_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event
):
self.emit('periodic_advertising_sync_loss', event.sync_handle)
def on_hci_le_periodic_advertising_report_event(self, event):
def on_hci_le_periodic_advertising_report_event(
self, event: hci.HCI_LE_Periodic_Advertising_Report_Event
):
self.emit('periodic_advertising_report', event.sync_handle, event)
def on_hci_le_biginfo_advertising_report_event(self, event):
def on_hci_le_biginfo_advertising_report_event(
self, event: hci.HCI_LE_BIGInfo_Advertising_Report_Event
):
self.emit('biginfo_advertising_report', event.sync_handle, event)
def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_request_event(self, event: hci.HCI_LE_CIS_Request_Event):
self.emit(
'cis_request',
event.acl_connection_handle,
@@ -1182,10 +1224,12 @@ class Host(utils.EventEmitter):
event.cis_id,
)
def on_hci_le_create_big_complete_event(self, event):
def on_hci_le_create_big_complete_event(
self, event: hci.HCI_LE_Create_BIG_Complete_Event
):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
logger.warning("BIS established but ISO packets not supported")
raise InvalidStateError("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
@@ -1208,8 +1252,13 @@ class Host(utils.EventEmitter):
event.iso_interval,
)
def on_hci_le_big_sync_established_event(self, event):
def on_hci_le_big_sync_established_event(
self, event: hci.HCI_LE_BIG_Sync_Established_Event
):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
raise InvalidStateError("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
@@ -1229,15 +1278,19 @@ class Host(utils.EventEmitter):
event.connection_handle,
)
def on_hci_le_big_sync_lost_event(self, event):
def on_hci_le_big_sync_lost_event(self, event: hci.HCI_LE_BIG_Sync_Lost_Event):
self.remove_big(event.big_handle)
self.emit('big_sync_lost', event.big_handle, event.reason)
def on_hci_le_terminate_big_complete_event(self, event):
def on_hci_le_terminate_big_complete_event(
self, event: hci.HCI_LE_Terminate_BIG_Complete_Event
):
self.remove_big(event.big_handle)
self.emit('big_termination', event.reason, event.big_handle)
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
def on_hci_le_periodic_advertising_sync_transfer_received_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event
):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
@@ -1250,7 +1303,9 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_V2_Event
):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
@@ -1263,11 +1318,11 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_cis_established_event(self, event):
def on_hci_le_cis_established_event(self, event: hci.HCI_LE_CIS_Established_Event):
# The remaining parameters are unused for now.
if event.status == hci.HCI_SUCCESS:
if self.iso_packet_queue is None:
logger.warning("CIS established but ISO packets not supported")
raise InvalidStateError("CIS established but ISO packets not supported")
self.cis_links[event.connection_handle] = IsoLink(
handle=event.connection_handle, packet_queue=self.iso_packet_queue
)
@@ -1294,7 +1349,9 @@ class Host(utils.EventEmitter):
'cis_establishment_failure', event.connection_handle, event.status
)
def on_hci_le_remote_connection_parameter_request_event(self, event):
def on_hci_le_remote_connection_parameter_request_event(
self, event: hci.HCI_LE_Remote_Connection_Parameter_Request_Event
):
if event.connection_handle not in self.connections:
logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
return
@@ -1313,7 +1370,9 @@ class Host(utils.EventEmitter):
)
)
def on_hci_le_long_term_key_request_event(self, event):
def on_hci_le_long_term_key_request_event(
self, event: hci.HCI_LE_Long_Term_Key_Request_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
return
@@ -1347,7 +1406,9 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_long_term_key())
def on_hci_synchronous_connection_complete_event(self, event):
def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event
):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
@@ -1373,7 +1434,9 @@ class Host(utils.EventEmitter):
# Notify the client
self.emit('sco_connection_failure', event.bd_addr, event.status)
def on_hci_synchronous_connection_changed_event(self, event):
def on_hci_synchronous_connection_changed_event(
self, event: hci.HCI_Synchronous_Connection_Changed_Event
):
pass
def on_hci_mode_change_event(self, event: hci.HCI_Mode_Change_Event):
@@ -1385,7 +1448,7 @@ class Host(utils.EventEmitter):
event.interval,
)
def on_hci_role_change_event(self, event):
def on_hci_role_change_event(self, event: hci.HCI_Role_Change_Event):
if event.status == hci.HCI_SUCCESS:
logger.debug(
f'role change for {event.bd_addr}: '
@@ -1399,7 +1462,9 @@ class Host(utils.EventEmitter):
)
self.emit('role_change_failure', event.bd_addr, event.status)
def on_hci_le_data_length_change_event(self, event):
def on_hci_le_data_length_change_event(
self, event: hci.HCI_LE_Data_Length_Change_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! DATA LENGTH CHANGE: unknown handle')
return
@@ -1413,7 +1478,9 @@ class Host(utils.EventEmitter):
event.max_rx_time,
)
def on_hci_authentication_complete_event(self, event):
def on_hci_authentication_complete_event(
self, event: hci.HCI_Authentication_Complete_Event
):
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle)
@@ -1454,7 +1521,9 @@ class Host(utils.EventEmitter):
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event):
def on_hci_encryption_key_refresh_complete_event(
self, event: hci.HCI_Encryption_Key_Refresh_Complete_Event
):
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle)
@@ -1465,7 +1534,7 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_qos_setup_complete_event(self, event):
def on_hci_qos_setup_complete_event(self, event: hci.HCI_QOS_Setup_Complete_Event):
if event.status == hci.HCI_SUCCESS:
self.emit(
'connection_qos_setup', event.connection_handle, event.service_type
@@ -1477,23 +1546,31 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
def on_hci_link_supervision_timeout_changed_event(
self, event: hci.HCI_Link_Supervision_Timeout_Changed_Event
):
pass
def on_hci_max_slots_change_event(self, event):
def on_hci_max_slots_change_event(self, event: hci.HCI_Max_Slots_Change_Event):
pass
def on_hci_page_scan_repetition_mode_change_event(self, event):
def on_hci_page_scan_repetition_mode_change_event(
self, event: hci.HCI_Page_Scan_Repetition_Mode_Change_Event
):
pass
def on_hci_link_key_notification_event(self, event):
def on_hci_link_key_notification_event(
self, event: hci.HCI_Link_Key_Notification_Event
):
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
def on_hci_simple_pairing_complete_event(
self, event: hci.HCI_Simple_Pairing_Complete_Event
):
logger.debug(
f'simple pairing complete for {event.bd_addr}: '
f'status={hci.HCI_Constant.status_name(event.status)}'
@@ -1503,10 +1580,10 @@ class Host(utils.EventEmitter):
else:
self.emit('classic_pairing_failure', event.bd_addr, event.status)
def on_hci_pin_code_request_event(self, event):
def on_hci_pin_code_request_event(self, event: hci.HCI_PIN_Code_Request_Event):
self.emit('pin_code_request', event.bd_addr)
def on_hci_link_key_request_event(self, event):
def on_hci_link_key_request_event(self, event: hci.HCI_Link_Key_Request_Event):
async def send_link_key():
if self.link_key_provider is None:
logger.debug('no link key provider')
@@ -1531,10 +1608,14 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_link_key())
def on_hci_io_capability_request_event(self, event):
def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event
):
self.emit('authentication_io_capability_request', event.bd_addr)
def on_hci_io_capability_response_event(self, event):
def on_hci_io_capability_response_event(
self, event: hci.HCI_IO_Capability_Response_Event
):
self.emit(
'authentication_io_capability_response',
event.bd_addr,
@@ -1542,25 +1623,33 @@ class Host(utils.EventEmitter):
event.authentication_requirements,
)
def on_hci_user_confirmation_request_event(self, event):
def on_hci_user_confirmation_request_event(
self, event: hci.HCI_User_Confirmation_Request_Event
):
self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event):
def on_hci_user_passkey_request_event(
self, event: hci.HCI_User_Passkey_Request_Event
):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
def on_hci_user_passkey_notification_event(
self, event: hci.HCI_User_Passkey_Notification_Event
):
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, _event):
def on_hci_inquiry_complete_event(self, _event: hci.HCI_Inquiry_Complete_Event):
self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event):
def on_hci_inquiry_result_with_rssi_event(
self, event: hci.HCI_Inquiry_Result_With_RSSI_Event
):
for bd_addr, class_of_device, rssi in zip(
event.bd_addr, event.class_of_device, event.rssi
):
@@ -1572,7 +1661,9 @@ class Host(utils.EventEmitter):
rssi,
)
def on_hci_extended_inquiry_result_event(self, event):
def on_hci_extended_inquiry_result_event(
self, event: hci.HCI_Extended_Inquiry_Result_Event
):
self.emit(
'inquiry_result',
event.bd_addr,
@@ -1581,7 +1672,9 @@ class Host(utils.EventEmitter):
event.rssi,
)
def on_hci_remote_name_request_complete_event(self, event):
def on_hci_remote_name_request_complete_event(
self, event: hci.HCI_Remote_Name_Request_Complete_Event
):
if event.status != hci.HCI_SUCCESS:
self.emit('remote_name_failure', event.bd_addr, event.status)
else:
@@ -1592,14 +1685,18 @@ class Host(utils.EventEmitter):
self.emit('remote_name', event.bd_addr, utf8_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
def on_hci_remote_host_supported_features_notification_event(
self, event: hci.HCI_Remote_Host_Supported_Features_Notification_Event
):
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)
def on_hci_le_read_remote_features_complete_event(self, event):
def on_hci_le_read_remote_features_complete_event(
self, event: hci.HCI_LE_Read_Remote_Features_Complete_Event
):
if event.status != hci.HCI_SUCCESS:
self.emit(
'le_remote_features_failure', event.connection_handle, event.status
@@ -1611,22 +1708,34 @@ class Host(utils.EventEmitter):
int.from_bytes(event.le_features, 'little'),
)
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(
self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event
):
self.emit('cs_remote_supported_capabilities', event)
def on_hci_le_cs_security_enable_complete_event(self, event):
def on_hci_le_cs_security_enable_complete_event(
self, event: hci.HCI_LE_CS_Security_Enable_Complete_Event
):
self.emit('cs_security', event)
def on_hci_le_cs_config_complete_event(self, event):
def on_hci_le_cs_config_complete_event(
self, event: hci.HCI_LE_CS_Config_Complete_Event
):
self.emit('cs_config', event)
def on_hci_le_cs_procedure_enable_complete_event(self, event):
def on_hci_le_cs_procedure_enable_complete_event(
self, event: hci.HCI_LE_CS_Procedure_Enable_Complete_Event
):
self.emit('cs_procedure', event)
def on_hci_le_cs_subevent_result_event(self, event):
def on_hci_le_cs_subevent_result_event(
self, event: hci.HCI_LE_CS_Subevent_Result_Event
):
self.emit('cs_subevent_result', event)
def on_hci_le_cs_subevent_result_continue_event(self, event):
def on_hci_le_cs_subevent_result_continue_event(
self, event: hci.HCI_LE_CS_Subevent_Result_Continue_Event
):
self.emit('cs_subevent_result_continue', event)
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
@@ -1639,5 +1748,5 @@ class Host(utils.EventEmitter):
event.supervision_timeout,
)
def on_hci_vendor_event(self, event):
def on_hci_vendor_event(self, event: hci.HCI_Vendor_Event):
self.emit('vendor_event', event)

View File

@@ -18,7 +18,6 @@
from __future__ import annotations
import asyncio
import functools
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Union
@@ -272,7 +271,7 @@ class HearingAccessService(gatt.TemplateService):
def on_connection(connection: Connection) -> None:
@connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
self.currently_connected_clients.discard(connection)
@connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None:
@@ -373,8 +372,7 @@ class HearingAccessService(gatt.TemplateService):
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
][:num_presets]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
@@ -383,7 +381,10 @@ class HearingAccessService(gatt.TemplateService):
async def _read_preset_response(
self, connection: Connection, presets: list[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Read Presets Request operation
# aborted and shall not either continue or restart the operation when the client
# reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
@@ -404,7 +405,7 @@ class HearingAccessService(gatt.TemplateService):
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
await self._notify_preset_operations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
@@ -413,14 +414,14 @@ class HearingAccessService(gatt.TemplateService):
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
await self._notify_preset_operations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
await self._notify_preset_operations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
@@ -432,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
@@ -447,8 +448,10 @@ class HearingAccessService(gatt.TemplateService):
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Preset Changed operation aborted and
# shall continue the operation when the client reconnects.
while op_list:
try:
await connection.device.indicate_subscriber(
connection,
@@ -460,14 +463,15 @@ class HearingAccessService(gatt.TemplateService):
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
for history_list in self.preset_changed_operations_history_per_device.values():
history_list.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(self, connection: Connection, value: bytes):
del connection # Unused
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -532,48 +536,51 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(self, _: Connection, value: bytes):
async def _on_set_active_preset(self, connection: Connection, value: bytes):
del connection # Unused
await self.set_active_preset(value)
async def set_next_or_previous_preset(self, is_previous):
async def set_next_or_previous_preset(self, is_previous: bool) -> None:
'''Set the next or the previous preset as active'''
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
presets = sorted(
[
record
for record in self.preset_records.values()
if record.is_available()
],
key=lambda record: record.index,
)
current_preset = self.preset_records[self.active_preset_index]
current_preset_pos = presets.index(current_preset)
if is_previous:
new_preset = presets[(current_preset_pos - 1) % len(presets)]
else:
new_preset = presets[(current_preset_pos + 1) % len(presets)]
if not first_preset: # If no other preset are available
if current_preset == new_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
self.active_preset_index = new_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(False)
async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_previous_preset(
self, connection: Connection, value: bytes
) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(True)
async def _on_set_active_preset_synchronized_locally(
self, _: Connection, value: bytes
self, connection: Connection, value: bytes
):
del connection # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -584,8 +591,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_active_preset(value)
async def _on_set_next_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -596,8 +604,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
async def _on_set_previous_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -615,11 +624,13 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
active_preset_index_notification: asyncio.Queue
preset_control_point_indications: asyncio.Queue[bytes]
active_preset_index_notification: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid(
@@ -641,20 +652,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
async def setup_subscription(self) -> None:
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
self.preset_control_point_indications.put_nowait,
prefer_notify=False,
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
self.active_preset_index_notification.put_nowait
)

View File

@@ -15,67 +15,261 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
from __future__ import annotations
import struct
from collections.abc import Sequence
import pytest
from bumble import avc, avctp, avrcp, controller, core, device, host, link
from bumble.transport import common
from bumble import avc, avctp, avrcp
from . import test_utils
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = link.LocalLink()
self.controllers = [
controller.Controller('C1', link=self.link, public_address=addresses[0]),
controller.Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
device.Device(
address=addresses[0],
host=host.Host(
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
),
),
device.Device(
address=addresses[1],
host=host.Host(
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
),
),
]
self.devices[0].classic_enabled = True
self.devices[1].classic_enabled = True
self.connections = [None, None]
self.protocols = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
async def setup_connections(self):
await self.devices[0].power_on()
await self.devices[1].power_on()
self.connections = await asyncio.gather(
self.devices[0].connect(
self.devices[1].public_address, core.PhysicalTransport.BR_EDR
),
self.devices[1].accept(self.devices[0].public_address),
)
class TwoDevices(test_utils.TwoDevices):
protocols: Sequence[avrcp.Protocol] = ()
async def setup_avdtp_connections(self):
self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
self.protocols[0].listen(self.devices[1])
await self.protocols[1].connect(self.connections[0])
@classmethod
async def create_with_avdtp(cls) -> TwoDevices:
devices = await cls.create_with_connection()
await devices.setup_avdtp_connections()
return devices
@pytest.mark.parametrize(
"command,",
[
avrcp.GetPlayStatusCommand(),
avrcp.GetCapabilitiesCommand(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID
),
avrcp.SetAbsoluteVolumeCommand(volume=5),
avrcp.GetElementAttributesCommand(
identifier=999,
attribute_ids=[
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.ARTIST_NAME,
],
),
avrcp.RegisterNotificationCommand(
event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123
),
avrcp.SearchCommand(
character_set_id=avrcp.CharacterSetId.UTF_8, search_string="Bumble!"
),
avrcp.PlayItemCommand(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
),
avrcp.ListPlayerApplicationSettingAttributesCommand(),
avrcp.ListPlayerApplicationSettingValuesCommand(
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
),
avrcp.GetCurrentPlayerApplicationSettingValueCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.SetPlayerApplicationSettingValueCommand(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
),
avrcp.GetPlayerApplicationSettingAttributeTextCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.GetPlayerApplicationSettingValueTextCommand(
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
],
),
avrcp.InformDisplayableCharacterSetCommand(
character_set_id=[avrcp.CharacterSetId.UTF_8]
),
avrcp.InformBatteryStatusOfCtCommand(
battery_status=avrcp.InformBatteryStatusOfCtCommand.BatteryStatus.NORMAL
),
avrcp.SetAddressedPlayerCommand(player_id=1),
avrcp.SetBrowsedPlayerCommand(player_id=1),
avrcp.GetFolderItemsCommand(
scope=avrcp.Scope.NOW_PLAYING,
start_item=0,
end_item=1,
attributes=[avrcp.MediaAttributeId.ARTIST_NAME],
),
avrcp.ChangePathCommand(
uid_counter=1,
direction=avrcp.ChangePathCommand.Direction.DOWN,
folder_uid=2,
),
avrcp.GetItemAttributesCommand(
scope=avrcp.Scope.NOW_PLAYING,
uid=0,
uid_counter=1,
start_item=0,
end_item=0,
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
),
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
avrcp.AddToNowPlayingCommand(
scope=avrcp.Scope.NOW_PLAYING, uid=0, uid_counter=1
),
],
)
def test_command(command: avrcp.Command):
assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command
@pytest.mark.parametrize(
"event,",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(identifier=b'12356'),
avrcp.VolumeChangedEvent(volume=9),
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
avrcp.AddressedPlayerChangedEvent(
player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10)
),
avrcp.AvailablePlayersChangedEvent(),
avrcp.PlaybackPositionChangedEvent(playback_position=1314),
avrcp.NowPlayingContentChangedEvent(),
avrcp.PlayerApplicationSettingChangedEvent(
player_application_settings=[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
)
]
),
],
)
def test_event(event: avrcp.Event):
assert avrcp.Event.from_bytes(bytes(event)) == event
@pytest.mark.parametrize(
"response,",
[
avrcp.GetPlayStatusResponse(
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
),
avrcp.GetCapabilitiesResponse(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED,
capabilities=[
avrcp.EventId.ADDRESSED_PLAYER_CHANGED,
avrcp.EventId.BATT_STATUS_CHANGED,
],
),
avrcp.RegisterNotificationResponse(
event=avrcp.PlaybackPositionChangedEvent(playback_position=38)
),
avrcp.SetAbsoluteVolumeResponse(volume=99),
avrcp.GetElementAttributesResponse(
attributes=[
avrcp.MediaAttribute(
attribute_id=avrcp.MediaAttributeId.ALBUM_NAME,
attribute_value="White Album",
character_set_id=avrcp.CharacterSetId.UTF_8,
)
]
),
avrcp.ListPlayerApplicationSettingAttributesResponse(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.ListPlayerApplicationSettingValuesResponse(
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
]
),
avrcp.GetCurrentPlayerApplicationSettingValueResponse(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
),
avrcp.SetPlayerApplicationSettingValueResponse(),
avrcp.GetPlayerApplicationSettingAttributeTextResponse(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
character_set_id=[avrcp.CharacterSetId.UTF_8],
attribute_string=["Repeat"],
),
avrcp.GetPlayerApplicationSettingValueTextResponse(
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
character_set_id=[avrcp.CharacterSetId.UTF_8],
attribute_string=["All track repeat"],
),
avrcp.InformDisplayableCharacterSetResponse(),
avrcp.InformBatteryStatusOfCtResponse(),
avrcp.SetAddressedPlayerResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
avrcp.SetBrowsedPlayerResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
numbers_of_items=2,
character_set_id=avrcp.CharacterSetId.UTF_8,
folder_names=["folder1", "folder2"],
),
avrcp.GetFolderItemsResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
items=[
avrcp.MediaPlayerItem(
player_id=1,
major_player_type=avrcp.MediaPlayerItem.MajorPlayerType.AUDIO,
player_sub_type=avrcp.MediaPlayerItem.PlayerSubType.AUDIO_BOOK,
play_status=avrcp.PlayStatus.FWD_SEEK,
feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Woo",
)
],
),
avrcp.ChangePathResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED, number_of_items=2
),
avrcp.GetItemAttributesResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
attribute_value_entry_list=[
avrcp.AttributeValueEntry(
attribute_id=avrcp.MediaAttributeId.GENRE,
character_set_id=avrcp.CharacterSetId.UTF_8,
attribute_value="uuddlrlrabab",
)
],
),
avrcp.GetTotalNumberOfItemsResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
number_of_items=2,
),
avrcp.SearchResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
number_of_items=2,
),
avrcp.PlayItemResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
avrcp.AddToNowPlayingResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
],
)
def test_response(response: avrcp.Response):
assert avrcp.Response.from_bytes(bytes(response), response.pdu_id) == response
# -----------------------------------------------------------------------------
def test_frame_parser():
with pytest.raises(ValueError) as error:
with pytest.raises(ValueError):
avc.Frame.from_bytes(bytes.fromhex("11480000"))
x = bytes.fromhex("014D0208")
@@ -217,8 +411,7 @@ def test_passthrough_commands():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
two_devices = TwoDevices()
await two_devices.setup_connections()
two_devices = await TwoDevices.create_with_avdtp()
supported_events = await two_devices.protocols[0].get_supported_events()
assert supported_events == []