forked from auracaster/bumble_mirror
Compare commits
17 Commits
gbg/data-t
...
revert-771
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca23d6b89a | ||
|
|
d86d69d816 | ||
|
|
dc93f32a9a | ||
|
|
9838908a26 | ||
|
|
613519f0b3 | ||
|
|
a943ea57ef | ||
|
|
14401910bb | ||
|
|
5d35ed471c | ||
|
|
c720ad5fdc | ||
|
|
f02183f95d | ||
|
|
d903937a51 | ||
|
|
6381ee0ab1 | ||
|
|
59d99780e1 | ||
|
|
9f3d8c9b49 | ||
|
|
31961febe5 | ||
|
|
dab0993cba | ||
|
|
3333ba472b |
1735
bumble/avrcp.py
1735
bumble/avrcp.py
File diff suppressed because it is too large
Load Diff
170
bumble/device.py
170
bumble/device.py
@@ -48,6 +48,7 @@ from typing_extensions import Self
|
||||
from bumble import (
|
||||
core,
|
||||
data_types,
|
||||
gatt,
|
||||
gatt_client,
|
||||
gatt_server,
|
||||
hci,
|
||||
@@ -264,7 +265,7 @@ class ExtendedAdvertisement(Advertisement):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AdvertisementDataAccumulator:
|
||||
def __init__(self, passive=False):
|
||||
def __init__(self, passive: bool = False):
|
||||
self.passive = passive
|
||||
self.last_advertisement = None
|
||||
self.last_data = b''
|
||||
@@ -1243,7 +1244,7 @@ class LePhyOptions:
|
||||
PREFER_S_2_CODED_PHY = 1
|
||||
PREFER_S_8_CODED_PHY = 2
|
||||
|
||||
def __init__(self, coded_phy_preference=0):
|
||||
def __init__(self, coded_phy_preference: int = 0):
|
||||
self.coded_phy_preference = coded_phy_preference
|
||||
|
||||
def __int__(self):
|
||||
@@ -1691,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter):
|
||||
self_address: hci.Address
|
||||
self_resolvable_address: Optional[hci.Address]
|
||||
peer_address: hci.Address
|
||||
peer_name: Optional[str]
|
||||
peer_resolvable_address: Optional[hci.Address]
|
||||
peer_le_features: Optional[hci.LeFeatureMask]
|
||||
role: hci.Role
|
||||
@@ -1931,7 +1933,7 @@ class Connection(utils.CompositeEventEmitter):
|
||||
self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result)
|
||||
self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception)
|
||||
|
||||
async def set_data_length(self, tx_octets, tx_time) -> None:
|
||||
async def set_data_length(self, tx_octets: int, tx_time: int) -> None:
|
||||
return await self.device.set_data_length(self, tx_octets, tx_time)
|
||||
|
||||
async def update_parameters(
|
||||
@@ -1961,7 +1963,12 @@ class Connection(utils.CompositeEventEmitter):
|
||||
use_l2cap=use_l2cap,
|
||||
)
|
||||
|
||||
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
|
||||
async def set_phy(
|
||||
self,
|
||||
tx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
rx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
phy_options: int = 0,
|
||||
):
|
||||
return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options)
|
||||
|
||||
async def get_phy(self) -> ConnectionPHY:
|
||||
@@ -2156,7 +2163,7 @@ class DeviceConfiguration:
|
||||
# Decorator that converts the first argument from a connection handle to a connection
|
||||
def with_connection_from_handle(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, connection_handle, *args, **kwargs):
|
||||
def wrapper(self, connection_handle: int, *args, **kwargs):
|
||||
if (connection := self.lookup_connection(connection_handle)) is None:
|
||||
raise ObjectLookupError(
|
||||
f'no connection for handle: 0x{connection_handle:04x}'
|
||||
@@ -2169,7 +2176,7 @@ def with_connection_from_handle(function):
|
||||
# Decorator that converts the first argument from a bluetooth address to a connection
|
||||
def with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
def wrapper(self, address: hci.Address, *args, **kwargs):
|
||||
if connection := self.pending_connections.get(address, False):
|
||||
return function(self, connection, *args, **kwargs)
|
||||
for connection in self.connections.values():
|
||||
@@ -2655,7 +2662,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||
|
||||
async def send_command(self, command, check_result=False):
|
||||
async def send_command(self, command: hci.HCI_Command, check_result: bool = False):
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.host.send_command(command, check_result), self.command_timeout
|
||||
@@ -2912,13 +2919,13 @@ class Device(utils.CompositeEventEmitter):
|
||||
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
|
||||
return self.host.supports_le_features(feature)
|
||||
|
||||
def supports_le_phy(self, phy: int) -> bool:
|
||||
if phy == hci.HCI_LE_1M_PHY:
|
||||
def supports_le_phy(self, phy: hci.Phy) -> bool:
|
||||
if phy == hci.Phy.LE_1M:
|
||||
return True
|
||||
|
||||
feature_map: dict[int, hci.LeFeatureMask] = {
|
||||
hci.HCI_LE_2M_PHY: hci.LeFeatureMask.LE_2M_PHY,
|
||||
hci.HCI_LE_CODED_PHY: hci.LeFeatureMask.LE_CODED_PHY,
|
||||
feature_map: dict[hci.Phy, hci.LeFeatureMask] = {
|
||||
hci.Phy.LE_2M: hci.LeFeatureMask.LE_2M_PHY,
|
||||
hci.Phy.LE_CODED: hci.LeFeatureMask.LE_CODED_PHY,
|
||||
}
|
||||
if phy not in feature_map:
|
||||
raise InvalidArgumentError('invalid PHY')
|
||||
@@ -3522,7 +3529,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
self.discovering = False
|
||||
|
||||
@host_event_handler
|
||||
def on_inquiry_result(self, address, class_of_device, data, rssi):
|
||||
def on_inquiry_result(
|
||||
self, address: hci.Address, class_of_device: int, data: bytes, rssi: int
|
||||
):
|
||||
self.emit(
|
||||
self.EVENT_INQUIRY_RESULT,
|
||||
address,
|
||||
@@ -3531,7 +3540,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
rssi,
|
||||
)
|
||||
|
||||
async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled):
|
||||
async def set_scan_enable(
|
||||
self, inquiry_scan_enabled: bool, page_scan_enabled: bool
|
||||
):
|
||||
if inquiry_scan_enabled and page_scan_enabled:
|
||||
scan_enable = 0x03
|
||||
elif page_scan_enabled:
|
||||
@@ -3657,6 +3668,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
always_resolve = False
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, transport
|
||||
) # TODO: timeout
|
||||
@@ -3684,7 +3696,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error):
|
||||
def on_connection_failure(error: core.ConnectionError):
|
||||
if transport == PhysicalTransport.LE or (
|
||||
# match BR/EDR connection failure event against peer address
|
||||
error.transport == transport
|
||||
@@ -3904,6 +3916,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
except InvalidArgumentError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, PhysicalTransport.BR_EDR
|
||||
) # TODO: timeout
|
||||
@@ -3962,7 +3975,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error):
|
||||
def on_connection_failure(error: core.ConnectionError):
|
||||
if (
|
||||
error.transport == PhysicalTransport.BR_EDR
|
||||
and error.peer_address == peer_address
|
||||
@@ -3999,7 +4012,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_as_gatt(self, peer_address):
|
||||
async def connect_as_gatt(self, peer_address: Union[hci.Address, str]):
|
||||
async with AsyncExitStack() as stack:
|
||||
connection = await stack.enter_async_context(
|
||||
await self.connect(peer_address)
|
||||
@@ -4035,6 +4048,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
except InvalidArgumentError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, PhysicalTransport.BR_EDR
|
||||
) # TODO: timeout
|
||||
@@ -4080,7 +4094,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
)
|
||||
self.disconnecting = False
|
||||
|
||||
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
|
||||
async def set_data_length(
|
||||
self, connection: Connection, tx_octets: int, tx_time: int
|
||||
) -> None:
|
||||
if tx_octets < 0x001B or tx_octets > 0x00FB:
|
||||
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
|
||||
|
||||
@@ -4183,7 +4199,11 @@ class Device(utils.CompositeEventEmitter):
|
||||
)
|
||||
|
||||
async def set_connection_phy(
|
||||
self, connection, tx_phys=None, rx_phys=None, phy_options=None
|
||||
self,
|
||||
connection: Connection,
|
||||
tx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
rx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
phy_options: int = 0,
|
||||
):
|
||||
if not self.host.supports_command(hci.HCI_LE_SET_PHY_COMMAND):
|
||||
logger.warning('ignoring request, command not supported')
|
||||
@@ -4199,7 +4219,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
all_phys=all_phys_bits,
|
||||
tx_phys=hci.phy_list_to_bits(tx_phys),
|
||||
rx_phys=hci.phy_list_to_bits(rx_phys),
|
||||
phy_options=0 if phy_options is None else int(phy_options),
|
||||
phy_options=phy_options,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4210,7 +4230,11 @@ class Device(utils.CompositeEventEmitter):
|
||||
)
|
||||
raise hci.HCI_StatusError(result)
|
||||
|
||||
async def set_default_phy(self, tx_phys=None, rx_phys=None):
|
||||
async def set_default_phy(
|
||||
self,
|
||||
tx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
rx_phys: Optional[Iterable[hci.Phy]] = None,
|
||||
):
|
||||
all_phys_bits = (1 if tx_phys is None else 0) | (
|
||||
(1 if rx_phys is None else 0) << 1
|
||||
)
|
||||
@@ -4248,7 +4272,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
async def find_peer_by_name(self, name, transport=PhysicalTransport.LE):
|
||||
async def find_peer_by_name(self, name: str, transport=PhysicalTransport.LE):
|
||||
"""
|
||||
Scan for a peer with a given name and return its address.
|
||||
"""
|
||||
@@ -4263,7 +4287,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
if local_name == name:
|
||||
peer_address.set_result(address)
|
||||
|
||||
listener = None
|
||||
listener: Optional[Callable[..., None]] = None
|
||||
was_scanning = self.scanning
|
||||
was_discovering = self.discovering
|
||||
try:
|
||||
@@ -4369,10 +4393,10 @@ class Device(utils.CompositeEventEmitter):
|
||||
def smp_session_proxy(self, session_proxy: type[smp.Session]) -> None:
|
||||
self.smp_manager.session_proxy = session_proxy
|
||||
|
||||
async def pair(self, connection):
|
||||
async def pair(self, connection: Connection):
|
||||
return await self.smp_manager.pair(connection)
|
||||
|
||||
def request_pairing(self, connection):
|
||||
def request_pairing(self, connection: Connection):
|
||||
return self.smp_manager.request_pairing(connection)
|
||||
|
||||
async def get_long_term_key(
|
||||
@@ -4460,7 +4484,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
on_authentication_failure,
|
||||
)
|
||||
|
||||
async def encrypt(self, connection, enable=True):
|
||||
async def encrypt(self, connection: Connection, enable: bool = True):
|
||||
if not enable and connection.transport == PhysicalTransport.LE:
|
||||
raise InvalidArgumentError('`enable` parameter is classic only.')
|
||||
|
||||
@@ -4470,7 +4494,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
def on_encryption_change():
|
||||
pending_encryption.set_result(None)
|
||||
|
||||
def on_encryption_failure(error_code):
|
||||
def on_encryption_failure(error_code: int):
|
||||
pending_encryption.set_exception(hci.HCI_Error(error_code))
|
||||
|
||||
connection.on(
|
||||
@@ -4562,10 +4586,10 @@ class Device(utils.CompositeEventEmitter):
|
||||
async def switch_role(self, connection: Connection, role: hci.Role):
|
||||
pending_role_change = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_role_change(new_role):
|
||||
def on_role_change(new_role: hci.Role):
|
||||
pending_role_change.set_result(new_role)
|
||||
|
||||
def on_role_change_failure(error_code):
|
||||
def on_role_change_failure(error_code: int):
|
||||
pending_role_change.set_exception(hci.HCI_Error(error_code))
|
||||
|
||||
connection.on(connection.EVENT_ROLE_CHANGE, on_role_change)
|
||||
@@ -5155,10 +5179,10 @@ class Device(utils.CompositeEventEmitter):
|
||||
):
|
||||
connection.emit(connection.EVENT_LINK_KEY)
|
||||
|
||||
def add_service(self, service):
|
||||
def add_service(self, service: gatt.Service):
|
||||
self.gatt_server.add_service(service)
|
||||
|
||||
def add_services(self, services):
|
||||
def add_services(self, services: Iterable[gatt.Service]):
|
||||
self.gatt_server.add_services(services)
|
||||
|
||||
def add_default_services(
|
||||
@@ -5254,10 +5278,10 @@ class Device(utils.CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
def on_advertising_set_termination(
|
||||
self,
|
||||
status,
|
||||
advertising_handle,
|
||||
connection_handle,
|
||||
number_of_completed_extended_advertising_events,
|
||||
status: int,
|
||||
advertising_handle: int,
|
||||
connection_handle: int,
|
||||
number_of_completed_extended_advertising_events: int,
|
||||
):
|
||||
# Legacy advertising set is also one of extended advertising sets.
|
||||
if not (
|
||||
@@ -5556,7 +5580,12 @@ class Device(utils.CompositeEventEmitter):
|
||||
)
|
||||
|
||||
@host_event_handler
|
||||
def on_connection_failure(self, transport, peer_address, error_code):
|
||||
def on_connection_failure(
|
||||
self,
|
||||
transport: hci.PhysicalTransport,
|
||||
peer_address: hci.Address,
|
||||
error_code: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'*** Connection failed: {hci.HCI_Constant.error_name(error_code)}'
|
||||
)
|
||||
@@ -5675,7 +5704,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_authentication(self, connection):
|
||||
def on_connection_authentication(self, connection: Connection):
|
||||
logger.debug(
|
||||
f'*** Connection Authentication: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}'
|
||||
@@ -5685,7 +5714,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_authentication_failure(self, connection, error):
|
||||
def on_connection_authentication_failure(
|
||||
self, connection: Connection, error: core.ConnectionError
|
||||
):
|
||||
logger.debug(
|
||||
f'*** Connection Authentication Failure: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, error={error}'
|
||||
@@ -5727,10 +5758,13 @@ class Device(utils.CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_io_capability_response(
|
||||
self, connection, io_capability, authentication_requirements
|
||||
self,
|
||||
connection: Connection,
|
||||
io_capability: int,
|
||||
authentication_requirements: int,
|
||||
):
|
||||
connection.peer_pairing_io_capability = io_capability
|
||||
connection.peer_pairing_authentication_requirements = (
|
||||
connection.pairing_peer_io_capability = io_capability
|
||||
connection.pairing_peer_authentication_requirements = (
|
||||
authentication_requirements
|
||||
)
|
||||
|
||||
@@ -5741,7 +5775,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
io_capability = pairing_config.delegate.classic_io_capability
|
||||
peer_io_capability = connection.peer_pairing_io_capability
|
||||
peer_io_capability = connection.pairing_peer_io_capability
|
||||
|
||||
async def confirm() -> bool:
|
||||
# Ask the user to confirm the pairing, without display
|
||||
@@ -5816,7 +5850,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_passkey_request(self, connection) -> None:
|
||||
def on_authentication_user_passkey_request(self, connection: Connection) -> None:
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
|
||||
@@ -5859,7 +5893,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_pin_code_request(self, connection):
|
||||
def on_pin_code_request(self, connection: Connection):
|
||||
# Classic legacy pairing
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
@@ -5903,7 +5937,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_passkey_notification(self, connection, passkey):
|
||||
def on_authentication_user_passkey_notification(
|
||||
self, connection: Connection, passkey: int
|
||||
):
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
|
||||
@@ -5915,14 +5951,15 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@try_with_connection_from_address
|
||||
def on_remote_name(self, connection: Connection, address, remote_name):
|
||||
def on_remote_name(
|
||||
self, connection: Connection, address: hci.Address, remote_name: bytes
|
||||
):
|
||||
# Try to decode the name
|
||||
try:
|
||||
remote_name = remote_name.decode('utf-8')
|
||||
if connection:
|
||||
connection.peer_name = remote_name
|
||||
connection.peer_name = remote_name.decode('utf-8')
|
||||
connection.emit(connection.EVENT_REMOTE_NAME)
|
||||
self.emit(self.EVENT_REMOTE_NAME, address, remote_name)
|
||||
self.emit(self.EVENT_REMOTE_NAME, address, remote_name.decode('utf-8'))
|
||||
except UnicodeDecodeError as error:
|
||||
logger.warning('peer name is not valid UTF-8')
|
||||
if connection:
|
||||
@@ -5933,7 +5970,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@try_with_connection_from_address
|
||||
def on_remote_name_failure(self, connection: Connection, address, error):
|
||||
def on_remote_name_failure(
|
||||
self, connection: Connection, address: hci.Address, error: int
|
||||
):
|
||||
if connection:
|
||||
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
|
||||
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
|
||||
@@ -6134,7 +6173,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_encryption_key_refresh(self, connection):
|
||||
def on_connection_encryption_key_refresh(self, connection: Connection):
|
||||
logger.debug(
|
||||
f'*** Connection Key Refresh: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}'
|
||||
@@ -6172,7 +6211,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_parameters_update_failure(self, connection, error):
|
||||
def on_connection_parameters_update_failure(
|
||||
self, connection: Connection, error: int
|
||||
):
|
||||
logger.debug(
|
||||
f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, '
|
||||
@@ -6182,7 +6223,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_phy_update(self, connection, phy):
|
||||
def on_connection_phy_update(self, connection: Connection, phy: core.ConnectionPHY):
|
||||
logger.debug(
|
||||
f'*** Connection PHY Update: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, '
|
||||
@@ -6192,7 +6233,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_phy_update_failure(self, connection, error):
|
||||
def on_connection_phy_update_failure(self, connection: Connection, error: int):
|
||||
logger.debug(
|
||||
f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, '
|
||||
@@ -6221,7 +6262,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_att_mtu_update(self, connection, att_mtu):
|
||||
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
|
||||
logger.debug(
|
||||
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, '
|
||||
@@ -6233,7 +6274,12 @@ class Device(utils.CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_data_length_change(
|
||||
self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time
|
||||
self,
|
||||
connection: Connection,
|
||||
max_tx_octets: int,
|
||||
max_tx_time: int,
|
||||
max_rx_octets: int,
|
||||
max_rx_time: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'*** Connection Data Length Change: [0x{connection.handle:04X}] '
|
||||
@@ -6358,14 +6404,16 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_role_change(self, connection, new_role):
|
||||
def on_role_change(self, connection: Connection, new_role: hci.Role):
|
||||
connection.role = new_role
|
||||
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@try_with_connection_from_address
|
||||
def on_role_change_failure(self, connection, address, error):
|
||||
def on_role_change_failure(
|
||||
self, connection: Connection, address: hci.Address, error: int
|
||||
):
|
||||
if connection:
|
||||
connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error)
|
||||
self.emit(self.EVENT_ROLE_CHANGE_FAILURE, address, error)
|
||||
@@ -6379,7 +6427,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_classic_pairing_failure(self, connection: Connection, status) -> None:
|
||||
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
|
||||
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
|
||||
|
||||
def on_pairing_start(self, connection: Connection) -> None:
|
||||
@@ -6403,7 +6451,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
connection.emit(connection.EVENT_PAIRING_FAILURE, reason)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_gatt_pdu(self, connection, pdu):
|
||||
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
|
||||
# Parse the L2CAP payload into an ATT PDU object
|
||||
att_pdu = ATT_PDU.from_bytes(pdu)
|
||||
|
||||
@@ -6425,7 +6473,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
connection.gatt_server.on_gatt_pdu(connection, att_pdu)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_smp_pdu(self, connection, pdu):
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes):
|
||||
self.smp_manager.on_smp_pdu(connection, pdu)
|
||||
|
||||
@host_event_handler
|
||||
|
||||
@@ -26,7 +26,17 @@ import secrets
|
||||
import struct
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Any, Callable, ClassVar, Iterable, Optional, TypeVar, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -111,23 +121,57 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
|
||||
class SpecableEnum(utils.OpenIntEnum):
|
||||
|
||||
@classmethod
|
||||
def type_spec(cls, size: int):
|
||||
return {'size': size, 'mapper': lambda x: cls(x).name}
|
||||
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
|
||||
return {
|
||||
'serializer': lambda x: x.to_bytes(size, byteorder),
|
||||
'parser': lambda data, offset: (
|
||||
offset + size,
|
||||
cls(int.from_bytes(data[offset : offset + size], byteorder)),
|
||||
),
|
||||
'mapper': lambda x: cls(x).name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
|
||||
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
|
||||
def type_metadata(
|
||||
cls,
|
||||
size: int,
|
||||
list_begin: bool = False,
|
||||
list_end: bool = False,
|
||||
byteorder: Literal['little', 'big'] = 'little',
|
||||
):
|
||||
return metadata(
|
||||
cls.type_spec(size, byteorder),
|
||||
list_begin=list_begin,
|
||||
list_end=list_end,
|
||||
)
|
||||
|
||||
|
||||
class SpecableFlag(enum.IntFlag):
|
||||
|
||||
@classmethod
|
||||
def type_spec(cls, size: int):
|
||||
return {'size': size, 'mapper': lambda x: cls(x).name}
|
||||
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
|
||||
return {
|
||||
'serializer': lambda x: x.to_bytes(size, byteorder),
|
||||
'parser': lambda data, offset: (
|
||||
offset + size,
|
||||
cls(int.from_bytes(data[offset : offset + size], byteorder)),
|
||||
),
|
||||
'mapper': lambda x: cls(x).name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
|
||||
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
|
||||
def type_metadata(
|
||||
cls,
|
||||
size: int,
|
||||
list_begin: bool = False,
|
||||
list_end: bool = False,
|
||||
byteorder: Literal['little', 'big'] = 'little',
|
||||
):
|
||||
return metadata(
|
||||
cls.type_spec(size, byteorder),
|
||||
list_begin=list_begin,
|
||||
list_end=list_end,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -6422,7 +6466,9 @@ class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event):
|
||||
irc: int = field(metadata=metadata(1))
|
||||
max_pdu: int = field(metadata=metadata(2))
|
||||
iso_interval: int = field(metadata=metadata(2))
|
||||
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
|
||||
connection_handle: Sequence[int] = field(
|
||||
metadata=metadata(2, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -6454,7 +6500,9 @@ class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event):
|
||||
irc: int = field(metadata=metadata(1))
|
||||
max_pdu: int = field(metadata=metadata(2))
|
||||
iso_interval: int = field(metadata=metadata(2))
|
||||
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
|
||||
connection_handle: Sequence[int] = field(
|
||||
metadata=metadata(2, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
237
bumble/host.py
237
bumble/host.py
@@ -22,11 +22,16 @@ import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast
|
||||
|
||||
from bumble import drivers, hci, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import ConnectionParameters, ConnectionPHY, PhysicalTransport
|
||||
from bumble.core import (
|
||||
ConnectionParameters,
|
||||
ConnectionPHY,
|
||||
InvalidStateError,
|
||||
PhysicalTransport,
|
||||
)
|
||||
from bumble.l2cap import L2CAP_PDU
|
||||
from bumble.snoop import Snooper
|
||||
from bumble.transport.common import TransportLostError
|
||||
@@ -902,10 +907,14 @@ class Host(utils.EventEmitter):
|
||||
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
self.emit('l2cap_pdu', connection.handle, cid, pdu)
|
||||
|
||||
def on_command_processed(self, event):
|
||||
def on_command_processed(
|
||||
self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event]
|
||||
):
|
||||
if self.pending_response:
|
||||
# Check that it is what we were expecting
|
||||
if self.pending_command.op_code != event.command_opcode:
|
||||
if self.pending_command is None:
|
||||
logger.warning('!!! pending_command is None ')
|
||||
elif self.pending_command.op_code != event.command_opcode:
|
||||
logger.warning(
|
||||
'!!! command result mismatch, expected '
|
||||
f'0x{self.pending_command.op_code:X} but got '
|
||||
@@ -919,10 +928,10 @@ class Host(utils.EventEmitter):
|
||||
############################################################
|
||||
# HCI handlers
|
||||
############################################################
|
||||
def on_hci_event(self, event):
|
||||
def on_hci_event(self, event: hci.HCI_Event):
|
||||
logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')
|
||||
|
||||
def on_hci_command_complete_event(self, event):
|
||||
def on_hci_command_complete_event(self, event: hci.HCI_Command_Complete_Event):
|
||||
if event.command_opcode == 0:
|
||||
# This is used just for the Num_HCI_Command_Packets field, not related to
|
||||
# an actual command
|
||||
@@ -931,7 +940,7 @@ class Host(utils.EventEmitter):
|
||||
|
||||
return self.on_command_processed(event)
|
||||
|
||||
def on_hci_command_status_event(self, event):
|
||||
def on_hci_command_status_event(self, event: hci.HCI_Command_Status_Event):
|
||||
return self.on_command_processed(event)
|
||||
|
||||
def on_hci_number_of_completed_packets_event(
|
||||
@@ -951,7 +960,7 @@ class Host(utils.EventEmitter):
|
||||
)
|
||||
|
||||
# Classic only
|
||||
def on_hci_connection_request_event(self, event):
|
||||
def on_hci_connection_request_event(self, event: hci.HCI_Connection_Request_Event):
|
||||
# Notify the listeners
|
||||
self.emit(
|
||||
'connection_request',
|
||||
@@ -960,7 +969,14 @@ class Host(utils.EventEmitter):
|
||||
event.link_type,
|
||||
)
|
||||
|
||||
def on_hci_le_connection_complete_event(self, event):
|
||||
def on_hci_le_connection_complete_event(
|
||||
self,
|
||||
event: Union[
|
||||
hci.HCI_LE_Connection_Complete_Event,
|
||||
hci.HCI_LE_Enhanced_Connection_Complete_Event,
|
||||
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
|
||||
],
|
||||
):
|
||||
# Check if this is a cancellation
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
# Create/update the connection
|
||||
@@ -1006,15 +1022,25 @@ class Host(utils.EventEmitter):
|
||||
event.status,
|
||||
)
|
||||
|
||||
def on_hci_le_enhanced_connection_complete_event(self, event):
|
||||
def on_hci_le_enhanced_connection_complete_event(
|
||||
self,
|
||||
event: Union[
|
||||
hci.HCI_LE_Enhanced_Connection_Complete_Event,
|
||||
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
|
||||
],
|
||||
):
|
||||
# Just use the same implementation as for the non-enhanced event for now
|
||||
self.on_hci_le_connection_complete_event(event)
|
||||
|
||||
def on_hci_le_enhanced_connection_complete_v2_event(self, event):
|
||||
def on_hci_le_enhanced_connection_complete_v2_event(
|
||||
self, event: hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
|
||||
):
|
||||
# Just use the same implementation as for the v1 event for now
|
||||
self.on_hci_le_enhanced_connection_complete_event(event)
|
||||
|
||||
def on_hci_connection_complete_event(self, event):
|
||||
def on_hci_connection_complete_event(
|
||||
self, event: hci.HCI_Connection_Complete_Event
|
||||
):
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
# Create/update the connection
|
||||
logger.debug(
|
||||
@@ -1054,7 +1080,9 @@ class Host(utils.EventEmitter):
|
||||
event.status,
|
||||
)
|
||||
|
||||
def on_hci_disconnection_complete_event(self, event):
|
||||
def on_hci_disconnection_complete_event(
|
||||
self, event: hci.HCI_Disconnection_Complete_Event
|
||||
):
|
||||
# Find the connection
|
||||
handle = event.connection_handle
|
||||
if (
|
||||
@@ -1093,7 +1121,9 @@ class Host(utils.EventEmitter):
|
||||
# Notify the listeners
|
||||
self.emit('disconnection_failure', handle, event.status)
|
||||
|
||||
def on_hci_le_connection_update_complete_event(self, event):
|
||||
def on_hci_le_connection_update_complete_event(
|
||||
self, event: hci.HCI_LE_Connection_Update_Complete_Event
|
||||
):
|
||||
if (connection := self.connections.get(event.connection_handle)) is None:
|
||||
logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
|
||||
return
|
||||
@@ -1113,7 +1143,9 @@ class Host(utils.EventEmitter):
|
||||
'connection_parameters_update_failure', connection.handle, event.status
|
||||
)
|
||||
|
||||
def on_hci_le_phy_update_complete_event(self, event):
|
||||
def on_hci_le_phy_update_complete_event(
|
||||
self, event: hci.HCI_LE_PHY_Update_Complete_Event
|
||||
):
|
||||
if (connection := self.connections.get(event.connection_handle)) is None:
|
||||
logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
|
||||
return
|
||||
@@ -1143,7 +1175,9 @@ class Host(utils.EventEmitter):
|
||||
):
|
||||
self.on_hci_le_advertising_report_event(event)
|
||||
|
||||
def on_hci_le_advertising_set_terminated_event(self, event):
|
||||
def on_hci_le_advertising_set_terminated_event(
|
||||
self, event: hci.HCI_LE_Advertising_Set_Terminated_Event
|
||||
):
|
||||
self.emit(
|
||||
'advertising_set_termination',
|
||||
event.status,
|
||||
@@ -1152,7 +1186,9 @@ class Host(utils.EventEmitter):
|
||||
event.num_completed_extended_advertising_events,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_established_event(self, event):
|
||||
def on_hci_le_periodic_advertising_sync_established_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_Event
|
||||
):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_establishment',
|
||||
event.status,
|
||||
@@ -1164,16 +1200,22 @@ class Host(utils.EventEmitter):
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_lost_event(self, event):
|
||||
def on_hci_le_periodic_advertising_sync_lost_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event
|
||||
):
|
||||
self.emit('periodic_advertising_sync_loss', event.sync_handle)
|
||||
|
||||
def on_hci_le_periodic_advertising_report_event(self, event):
|
||||
def on_hci_le_periodic_advertising_report_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Report_Event
|
||||
):
|
||||
self.emit('periodic_advertising_report', event.sync_handle, event)
|
||||
|
||||
def on_hci_le_biginfo_advertising_report_event(self, event):
|
||||
def on_hci_le_biginfo_advertising_report_event(
|
||||
self, event: hci.HCI_LE_BIGInfo_Advertising_Report_Event
|
||||
):
|
||||
self.emit('biginfo_advertising_report', event.sync_handle, event)
|
||||
|
||||
def on_hci_le_cis_request_event(self, event):
|
||||
def on_hci_le_cis_request_event(self, event: hci.HCI_LE_CIS_Request_Event):
|
||||
self.emit(
|
||||
'cis_request',
|
||||
event.acl_connection_handle,
|
||||
@@ -1182,10 +1224,12 @@ class Host(utils.EventEmitter):
|
||||
event.cis_id,
|
||||
)
|
||||
|
||||
def on_hci_le_create_big_complete_event(self, event):
|
||||
def on_hci_le_create_big_complete_event(
|
||||
self, event: hci.HCI_LE_Create_BIG_Complete_Event
|
||||
):
|
||||
self.bigs[event.big_handle] = set(event.connection_handle)
|
||||
if self.iso_packet_queue is None:
|
||||
logger.warning("BIS established but ISO packets not supported")
|
||||
raise InvalidStateError("BIS established but ISO packets not supported")
|
||||
|
||||
for connection_handle in event.connection_handle:
|
||||
self.bis_links[connection_handle] = IsoLink(
|
||||
@@ -1208,8 +1252,13 @@ class Host(utils.EventEmitter):
|
||||
event.iso_interval,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_established_event(self, event):
|
||||
def on_hci_le_big_sync_established_event(
|
||||
self, event: hci.HCI_LE_BIG_Sync_Established_Event
|
||||
):
|
||||
self.bigs[event.big_handle] = set(event.connection_handle)
|
||||
if self.iso_packet_queue is None:
|
||||
raise InvalidStateError("BIS established but ISO packets not supported")
|
||||
|
||||
for connection_handle in event.connection_handle:
|
||||
self.bis_links[connection_handle] = IsoLink(
|
||||
connection_handle, self.iso_packet_queue
|
||||
@@ -1229,15 +1278,19 @@ class Host(utils.EventEmitter):
|
||||
event.connection_handle,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_lost_event(self, event):
|
||||
def on_hci_le_big_sync_lost_event(self, event: hci.HCI_LE_BIG_Sync_Lost_Event):
|
||||
self.remove_big(event.big_handle)
|
||||
self.emit('big_sync_lost', event.big_handle, event.reason)
|
||||
|
||||
def on_hci_le_terminate_big_complete_event(self, event):
|
||||
def on_hci_le_terminate_big_complete_event(
|
||||
self, event: hci.HCI_LE_Terminate_BIG_Complete_Event
|
||||
):
|
||||
self.remove_big(event.big_handle)
|
||||
self.emit('big_termination', event.reason, event.big_handle)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event
|
||||
):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_transfer',
|
||||
event.status,
|
||||
@@ -1250,7 +1303,9 @@ class Host(utils.EventEmitter):
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_V2_Event
|
||||
):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_transfer',
|
||||
event.status,
|
||||
@@ -1263,11 +1318,11 @@ class Host(utils.EventEmitter):
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_cis_established_event(self, event):
|
||||
def on_hci_le_cis_established_event(self, event: hci.HCI_LE_CIS_Established_Event):
|
||||
# The remaining parameters are unused for now.
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
if self.iso_packet_queue is None:
|
||||
logger.warning("CIS established but ISO packets not supported")
|
||||
raise InvalidStateError("CIS established but ISO packets not supported")
|
||||
self.cis_links[event.connection_handle] = IsoLink(
|
||||
handle=event.connection_handle, packet_queue=self.iso_packet_queue
|
||||
)
|
||||
@@ -1294,7 +1349,9 @@ class Host(utils.EventEmitter):
|
||||
'cis_establishment_failure', event.connection_handle, event.status
|
||||
)
|
||||
|
||||
def on_hci_le_remote_connection_parameter_request_event(self, event):
|
||||
def on_hci_le_remote_connection_parameter_request_event(
|
||||
self, event: hci.HCI_LE_Remote_Connection_Parameter_Request_Event
|
||||
):
|
||||
if event.connection_handle not in self.connections:
|
||||
logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
|
||||
return
|
||||
@@ -1313,7 +1370,9 @@ class Host(utils.EventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_long_term_key_request_event(self, event):
|
||||
def on_hci_le_long_term_key_request_event(
|
||||
self, event: hci.HCI_LE_Long_Term_Key_Request_Event
|
||||
):
|
||||
if (connection := self.connections.get(event.connection_handle)) is None:
|
||||
logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
|
||||
return
|
||||
@@ -1347,7 +1406,9 @@ class Host(utils.EventEmitter):
|
||||
|
||||
asyncio.create_task(send_long_term_key())
|
||||
|
||||
def on_hci_synchronous_connection_complete_event(self, event):
|
||||
def on_hci_synchronous_connection_complete_event(
|
||||
self, event: hci.HCI_Synchronous_Connection_Complete_Event
|
||||
):
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
# Create/update the connection
|
||||
logger.debug(
|
||||
@@ -1373,7 +1434,9 @@ class Host(utils.EventEmitter):
|
||||
# Notify the client
|
||||
self.emit('sco_connection_failure', event.bd_addr, event.status)
|
||||
|
||||
def on_hci_synchronous_connection_changed_event(self, event):
|
||||
def on_hci_synchronous_connection_changed_event(
|
||||
self, event: hci.HCI_Synchronous_Connection_Changed_Event
|
||||
):
|
||||
pass
|
||||
|
||||
def on_hci_mode_change_event(self, event: hci.HCI_Mode_Change_Event):
|
||||
@@ -1385,7 +1448,7 @@ class Host(utils.EventEmitter):
|
||||
event.interval,
|
||||
)
|
||||
|
||||
def on_hci_role_change_event(self, event):
|
||||
def on_hci_role_change_event(self, event: hci.HCI_Role_Change_Event):
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
logger.debug(
|
||||
f'role change for {event.bd_addr}: '
|
||||
@@ -1399,7 +1462,9 @@ class Host(utils.EventEmitter):
|
||||
)
|
||||
self.emit('role_change_failure', event.bd_addr, event.status)
|
||||
|
||||
def on_hci_le_data_length_change_event(self, event):
|
||||
def on_hci_le_data_length_change_event(
|
||||
self, event: hci.HCI_LE_Data_Length_Change_Event
|
||||
):
|
||||
if (connection := self.connections.get(event.connection_handle)) is None:
|
||||
logger.warning('!!! DATA LENGTH CHANGE: unknown handle')
|
||||
return
|
||||
@@ -1413,7 +1478,9 @@ class Host(utils.EventEmitter):
|
||||
event.max_rx_time,
|
||||
)
|
||||
|
||||
def on_hci_authentication_complete_event(self, event):
|
||||
def on_hci_authentication_complete_event(
|
||||
self, event: hci.HCI_Authentication_Complete_Event
|
||||
):
|
||||
# Notify the client
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
self.emit('connection_authentication', event.connection_handle)
|
||||
@@ -1454,7 +1521,9 @@ class Host(utils.EventEmitter):
|
||||
'connection_encryption_failure', event.connection_handle, event.status
|
||||
)
|
||||
|
||||
def on_hci_encryption_key_refresh_complete_event(self, event):
|
||||
def on_hci_encryption_key_refresh_complete_event(
|
||||
self, event: hci.HCI_Encryption_Key_Refresh_Complete_Event
|
||||
):
|
||||
# Notify the client
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
self.emit('connection_encryption_key_refresh', event.connection_handle)
|
||||
@@ -1465,7 +1534,7 @@ class Host(utils.EventEmitter):
|
||||
event.status,
|
||||
)
|
||||
|
||||
def on_hci_qos_setup_complete_event(self, event):
|
||||
def on_hci_qos_setup_complete_event(self, event: hci.HCI_QOS_Setup_Complete_Event):
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
self.emit(
|
||||
'connection_qos_setup', event.connection_handle, event.service_type
|
||||
@@ -1477,23 +1546,31 @@ class Host(utils.EventEmitter):
|
||||
event.status,
|
||||
)
|
||||
|
||||
def on_hci_link_supervision_timeout_changed_event(self, event):
|
||||
def on_hci_link_supervision_timeout_changed_event(
|
||||
self, event: hci.HCI_Link_Supervision_Timeout_Changed_Event
|
||||
):
|
||||
pass
|
||||
|
||||
def on_hci_max_slots_change_event(self, event):
|
||||
def on_hci_max_slots_change_event(self, event: hci.HCI_Max_Slots_Change_Event):
|
||||
pass
|
||||
|
||||
def on_hci_page_scan_repetition_mode_change_event(self, event):
|
||||
def on_hci_page_scan_repetition_mode_change_event(
|
||||
self, event: hci.HCI_Page_Scan_Repetition_Mode_Change_Event
|
||||
):
|
||||
pass
|
||||
|
||||
def on_hci_link_key_notification_event(self, event):
|
||||
def on_hci_link_key_notification_event(
|
||||
self, event: hci.HCI_Link_Key_Notification_Event
|
||||
):
|
||||
logger.debug(
|
||||
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
|
||||
f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}'
|
||||
)
|
||||
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
|
||||
|
||||
def on_hci_simple_pairing_complete_event(self, event):
|
||||
def on_hci_simple_pairing_complete_event(
|
||||
self, event: hci.HCI_Simple_Pairing_Complete_Event
|
||||
):
|
||||
logger.debug(
|
||||
f'simple pairing complete for {event.bd_addr}: '
|
||||
f'status={hci.HCI_Constant.status_name(event.status)}'
|
||||
@@ -1503,10 +1580,10 @@ class Host(utils.EventEmitter):
|
||||
else:
|
||||
self.emit('classic_pairing_failure', event.bd_addr, event.status)
|
||||
|
||||
def on_hci_pin_code_request_event(self, event):
|
||||
def on_hci_pin_code_request_event(self, event: hci.HCI_PIN_Code_Request_Event):
|
||||
self.emit('pin_code_request', event.bd_addr)
|
||||
|
||||
def on_hci_link_key_request_event(self, event):
|
||||
def on_hci_link_key_request_event(self, event: hci.HCI_Link_Key_Request_Event):
|
||||
async def send_link_key():
|
||||
if self.link_key_provider is None:
|
||||
logger.debug('no link key provider')
|
||||
@@ -1531,10 +1608,14 @@ class Host(utils.EventEmitter):
|
||||
|
||||
asyncio.create_task(send_link_key())
|
||||
|
||||
def on_hci_io_capability_request_event(self, event):
|
||||
def on_hci_io_capability_request_event(
|
||||
self, event: hci.HCI_IO_Capability_Request_Event
|
||||
):
|
||||
self.emit('authentication_io_capability_request', event.bd_addr)
|
||||
|
||||
def on_hci_io_capability_response_event(self, event):
|
||||
def on_hci_io_capability_response_event(
|
||||
self, event: hci.HCI_IO_Capability_Response_Event
|
||||
):
|
||||
self.emit(
|
||||
'authentication_io_capability_response',
|
||||
event.bd_addr,
|
||||
@@ -1542,25 +1623,33 @@ class Host(utils.EventEmitter):
|
||||
event.authentication_requirements,
|
||||
)
|
||||
|
||||
def on_hci_user_confirmation_request_event(self, event):
|
||||
def on_hci_user_confirmation_request_event(
|
||||
self, event: hci.HCI_User_Confirmation_Request_Event
|
||||
):
|
||||
self.emit(
|
||||
'authentication_user_confirmation_request',
|
||||
event.bd_addr,
|
||||
event.numeric_value,
|
||||
)
|
||||
|
||||
def on_hci_user_passkey_request_event(self, event):
|
||||
def on_hci_user_passkey_request_event(
|
||||
self, event: hci.HCI_User_Passkey_Request_Event
|
||||
):
|
||||
self.emit('authentication_user_passkey_request', event.bd_addr)
|
||||
|
||||
def on_hci_user_passkey_notification_event(self, event):
|
||||
def on_hci_user_passkey_notification_event(
|
||||
self, event: hci.HCI_User_Passkey_Notification_Event
|
||||
):
|
||||
self.emit(
|
||||
'authentication_user_passkey_notification', event.bd_addr, event.passkey
|
||||
)
|
||||
|
||||
def on_hci_inquiry_complete_event(self, _event):
|
||||
def on_hci_inquiry_complete_event(self, _event: hci.HCI_Inquiry_Complete_Event):
|
||||
self.emit('inquiry_complete')
|
||||
|
||||
def on_hci_inquiry_result_with_rssi_event(self, event):
|
||||
def on_hci_inquiry_result_with_rssi_event(
|
||||
self, event: hci.HCI_Inquiry_Result_With_RSSI_Event
|
||||
):
|
||||
for bd_addr, class_of_device, rssi in zip(
|
||||
event.bd_addr, event.class_of_device, event.rssi
|
||||
):
|
||||
@@ -1572,7 +1661,9 @@ class Host(utils.EventEmitter):
|
||||
rssi,
|
||||
)
|
||||
|
||||
def on_hci_extended_inquiry_result_event(self, event):
|
||||
def on_hci_extended_inquiry_result_event(
|
||||
self, event: hci.HCI_Extended_Inquiry_Result_Event
|
||||
):
|
||||
self.emit(
|
||||
'inquiry_result',
|
||||
event.bd_addr,
|
||||
@@ -1581,7 +1672,9 @@ class Host(utils.EventEmitter):
|
||||
event.rssi,
|
||||
)
|
||||
|
||||
def on_hci_remote_name_request_complete_event(self, event):
|
||||
def on_hci_remote_name_request_complete_event(
|
||||
self, event: hci.HCI_Remote_Name_Request_Complete_Event
|
||||
):
|
||||
if event.status != hci.HCI_SUCCESS:
|
||||
self.emit('remote_name_failure', event.bd_addr, event.status)
|
||||
else:
|
||||
@@ -1592,14 +1685,18 @@ class Host(utils.EventEmitter):
|
||||
|
||||
self.emit('remote_name', event.bd_addr, utf8_name)
|
||||
|
||||
def on_hci_remote_host_supported_features_notification_event(self, event):
|
||||
def on_hci_remote_host_supported_features_notification_event(
|
||||
self, event: hci.HCI_Remote_Host_Supported_Features_Notification_Event
|
||||
):
|
||||
self.emit(
|
||||
'remote_host_supported_features',
|
||||
event.bd_addr,
|
||||
event.host_supported_features,
|
||||
)
|
||||
|
||||
def on_hci_le_read_remote_features_complete_event(self, event):
|
||||
def on_hci_le_read_remote_features_complete_event(
|
||||
self, event: hci.HCI_LE_Read_Remote_Features_Complete_Event
|
||||
):
|
||||
if event.status != hci.HCI_SUCCESS:
|
||||
self.emit(
|
||||
'le_remote_features_failure', event.connection_handle, event.status
|
||||
@@ -1611,22 +1708,34 @@ class Host(utils.EventEmitter):
|
||||
int.from_bytes(event.le_features, 'little'),
|
||||
)
|
||||
|
||||
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
|
||||
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(
|
||||
self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event
|
||||
):
|
||||
self.emit('cs_remote_supported_capabilities', event)
|
||||
|
||||
def on_hci_le_cs_security_enable_complete_event(self, event):
|
||||
def on_hci_le_cs_security_enable_complete_event(
|
||||
self, event: hci.HCI_LE_CS_Security_Enable_Complete_Event
|
||||
):
|
||||
self.emit('cs_security', event)
|
||||
|
||||
def on_hci_le_cs_config_complete_event(self, event):
|
||||
def on_hci_le_cs_config_complete_event(
|
||||
self, event: hci.HCI_LE_CS_Config_Complete_Event
|
||||
):
|
||||
self.emit('cs_config', event)
|
||||
|
||||
def on_hci_le_cs_procedure_enable_complete_event(self, event):
|
||||
def on_hci_le_cs_procedure_enable_complete_event(
|
||||
self, event: hci.HCI_LE_CS_Procedure_Enable_Complete_Event
|
||||
):
|
||||
self.emit('cs_procedure', event)
|
||||
|
||||
def on_hci_le_cs_subevent_result_event(self, event):
|
||||
def on_hci_le_cs_subevent_result_event(
|
||||
self, event: hci.HCI_LE_CS_Subevent_Result_Event
|
||||
):
|
||||
self.emit('cs_subevent_result', event)
|
||||
|
||||
def on_hci_le_cs_subevent_result_continue_event(self, event):
|
||||
def on_hci_le_cs_subevent_result_continue_event(
|
||||
self, event: hci.HCI_LE_CS_Subevent_Result_Continue_Event
|
||||
):
|
||||
self.emit('cs_subevent_result_continue', event)
|
||||
|
||||
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
|
||||
@@ -1639,5 +1748,5 @@ class Host(utils.EventEmitter):
|
||||
event.supervision_timeout,
|
||||
)
|
||||
|
||||
def on_hci_vendor_event(self, event):
|
||||
def on_hci_vendor_event(self, event: hci.HCI_Vendor_Event):
|
||||
self.emit('vendor_event', event)
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
@@ -272,7 +271,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
def on_connection(connection: Connection) -> None:
|
||||
@connection.on(connection.EVENT_DISCONNECTION)
|
||||
def on_disconnection(_reason) -> None:
|
||||
self.currently_connected_clients.remove(connection)
|
||||
self.currently_connected_clients.discard(connection)
|
||||
|
||||
@connection.on(connection.EVENT_PAIRING)
|
||||
def on_pairing(*_: Any) -> None:
|
||||
@@ -373,8 +372,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.preset_records[key]
|
||||
for key in sorted(self.preset_records.keys())
|
||||
if self.preset_records[key].index >= start_index
|
||||
]
|
||||
del presets[num_presets:]
|
||||
][:num_presets]
|
||||
if len(presets) == 0:
|
||||
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
|
||||
|
||||
@@ -383,7 +381,10 @@ class HearingAccessService(gatt.TemplateService):
|
||||
async def _read_preset_response(
|
||||
self, connection: Connection, presets: list[PresetRecord]
|
||||
):
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
|
||||
# If the ATT bearer is terminated before all notifications or indications are
|
||||
# sent, then the server shall consider the Read Presets Request operation
|
||||
# aborted and shall not either continue or restart the operation when the client
|
||||
# reconnects.
|
||||
try:
|
||||
for i, preset in enumerate(presets):
|
||||
await connection.device.indicate_subscriber(
|
||||
@@ -404,7 +405,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
|
||||
async def generic_update(self, op: PresetChangedOperation) -> None:
|
||||
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
|
||||
await self._notifyPresetOperations(op)
|
||||
await self._notify_preset_operations(op)
|
||||
|
||||
async def delete_preset(self, index: int) -> None:
|
||||
'''Server API to delete a preset. It should not be the current active preset'''
|
||||
@@ -413,14 +414,14 @@ class HearingAccessService(gatt.TemplateService):
|
||||
raise InvalidStateError('Cannot delete active preset')
|
||||
|
||||
del self.preset_records[index]
|
||||
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationDeleted(index))
|
||||
|
||||
async def available_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset available'''
|
||||
|
||||
preset = self.preset_records[index]
|
||||
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
|
||||
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationAvailable(index))
|
||||
|
||||
async def unavailable_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset unavailable. It should not be the current active preset'''
|
||||
@@ -432,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
preset.properties.is_available = (
|
||||
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
|
||||
)
|
||||
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
|
||||
|
||||
async def _preset_changed_operation(self, connection: Connection) -> None:
|
||||
'''Send all PresetChangedOperation saved for a given connection'''
|
||||
@@ -447,8 +448,10 @@ class HearingAccessService(gatt.TemplateService):
|
||||
return op.additional_parameters
|
||||
|
||||
op_list.sort(key=get_op_index)
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
|
||||
while len(op_list) > 0:
|
||||
# If the ATT bearer is terminated before all notifications or indications are
|
||||
# sent, then the server shall consider the Preset Changed operation aborted and
|
||||
# shall continue the operation when the client reconnects.
|
||||
while op_list:
|
||||
try:
|
||||
await connection.device.indicate_subscriber(
|
||||
connection,
|
||||
@@ -460,14 +463,15 @@ class HearingAccessService(gatt.TemplateService):
|
||||
except TimeoutError:
|
||||
break
|
||||
|
||||
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
|
||||
for historyList in self.preset_changed_operations_history_per_device.values():
|
||||
historyList.append(op)
|
||||
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
|
||||
for history_list in self.preset_changed_operations_history_per_device.values():
|
||||
history_list.append(op)
|
||||
|
||||
for connection in self.currently_connected_clients:
|
||||
await self._preset_changed_operation(connection)
|
||||
|
||||
async def _on_write_preset_name(self, connection: Connection, value: bytes):
|
||||
del connection # Unused
|
||||
|
||||
if self.read_presets_request_in_progress:
|
||||
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
|
||||
@@ -532,48 +536,51 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.active_preset_index = index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_active_preset(self, _: Connection, value: bytes):
|
||||
async def _on_set_active_preset(self, connection: Connection, value: bytes):
|
||||
del connection # Unused
|
||||
await self.set_active_preset(value)
|
||||
|
||||
async def set_next_or_previous_preset(self, is_previous):
|
||||
async def set_next_or_previous_preset(self, is_previous: bool) -> None:
|
||||
'''Set the next or the previous preset as active'''
|
||||
|
||||
if self.active_preset_index == 0x00:
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
first_preset: Optional[PresetRecord] = None # To loop to first preset
|
||||
next_preset: Optional[PresetRecord] = None
|
||||
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
|
||||
if not record.is_available():
|
||||
continue
|
||||
if first_preset == None:
|
||||
first_preset = record
|
||||
if is_previous:
|
||||
if index >= self.active_preset_index:
|
||||
continue
|
||||
elif index <= self.active_preset_index:
|
||||
continue
|
||||
next_preset = record
|
||||
break
|
||||
presets = sorted(
|
||||
[
|
||||
record
|
||||
for record in self.preset_records.values()
|
||||
if record.is_available()
|
||||
],
|
||||
key=lambda record: record.index,
|
||||
)
|
||||
current_preset = self.preset_records[self.active_preset_index]
|
||||
current_preset_pos = presets.index(current_preset)
|
||||
if is_previous:
|
||||
new_preset = presets[(current_preset_pos - 1) % len(presets)]
|
||||
else:
|
||||
new_preset = presets[(current_preset_pos + 1) % len(presets)]
|
||||
|
||||
if not first_preset: # If no other preset are available
|
||||
if current_preset == new_preset: # If no other preset are available
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
if next_preset:
|
||||
self.active_preset_index = next_preset.index
|
||||
else:
|
||||
self.active_preset_index = first_preset.index
|
||||
self.active_preset_index = new_preset.index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None:
|
||||
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
|
||||
del connection, value # Unused.
|
||||
await self.set_next_or_previous_preset(False)
|
||||
|
||||
async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None:
|
||||
async def _on_set_previous_preset(
|
||||
self, connection: Connection, value: bytes
|
||||
) -> None:
|
||||
del connection, value # Unused.
|
||||
await self.set_next_or_previous_preset(True)
|
||||
|
||||
async def _on_set_active_preset_synchronized_locally(
|
||||
self, _: Connection, value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
@@ -584,8 +591,9 @@ class HearingAccessService(gatt.TemplateService):
|
||||
await self.other_server_in_binaural_set.set_active_preset(value)
|
||||
|
||||
async def _on_set_next_preset_synchronized_locally(
|
||||
self, _: Connection, __value__: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection, value # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
@@ -596,8 +604,9 @@ class HearingAccessService(gatt.TemplateService):
|
||||
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
|
||||
|
||||
async def _on_set_previous_preset_synchronized_locally(
|
||||
self, _: Connection, __value__: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection, value # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
@@ -615,11 +624,13 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = HearingAccessService
|
||||
|
||||
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
|
||||
preset_control_point_indications: asyncio.Queue
|
||||
active_preset_index_notification: asyncio.Queue
|
||||
preset_control_point_indications: asyncio.Queue[bytes]
|
||||
active_preset_index_notification: asyncio.Queue[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
self.preset_control_point_indications = asyncio.Queue()
|
||||
self.active_preset_index_notification = asyncio.Queue()
|
||||
|
||||
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
@@ -641,20 +652,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
'B',
|
||||
)
|
||||
|
||||
async def setup_subscription(self):
|
||||
self.preset_control_point_indications = asyncio.Queue()
|
||||
self.active_preset_index_notification = asyncio.Queue()
|
||||
|
||||
def on_active_preset_index_notification(data: bytes):
|
||||
self.active_preset_index_notification.put_nowait(data)
|
||||
|
||||
def on_preset_control_point_indication(data: bytes):
|
||||
self.preset_control_point_indications.put_nowait(data)
|
||||
|
||||
async def setup_subscription(self) -> None:
|
||||
await self.hearing_aid_preset_control_point.subscribe(
|
||||
functools.partial(on_preset_control_point_indication), prefer_notify=False
|
||||
self.preset_control_point_indications.put_nowait,
|
||||
prefer_notify=False,
|
||||
)
|
||||
|
||||
await self.active_preset_index.subscribe(
|
||||
functools.partial(on_active_preset_index_notification)
|
||||
self.active_preset_index_notification.put_nowait
|
||||
)
|
||||
|
||||
@@ -15,67 +15,261 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from bumble import avc, avctp, avrcp, controller, core, device, host, link
|
||||
from bumble.transport import common
|
||||
from bumble import avc, avctp, avrcp
|
||||
|
||||
from . import test_utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TwoDevices:
|
||||
def __init__(self):
|
||||
self.connections = [None, None]
|
||||
|
||||
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
|
||||
self.link = link.LocalLink()
|
||||
self.controllers = [
|
||||
controller.Controller('C1', link=self.link, public_address=addresses[0]),
|
||||
controller.Controller('C2', link=self.link, public_address=addresses[1]),
|
||||
]
|
||||
self.devices = [
|
||||
device.Device(
|
||||
address=addresses[0],
|
||||
host=host.Host(
|
||||
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
|
||||
),
|
||||
),
|
||||
device.Device(
|
||||
address=addresses[1],
|
||||
host=host.Host(
|
||||
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
|
||||
),
|
||||
),
|
||||
]
|
||||
self.devices[0].classic_enabled = True
|
||||
self.devices[1].classic_enabled = True
|
||||
self.connections = [None, None]
|
||||
self.protocols = [None, None]
|
||||
|
||||
def on_connection(self, which, connection):
|
||||
self.connections[which] = connection
|
||||
|
||||
async def setup_connections(self):
|
||||
await self.devices[0].power_on()
|
||||
await self.devices[1].power_on()
|
||||
|
||||
self.connections = await asyncio.gather(
|
||||
self.devices[0].connect(
|
||||
self.devices[1].public_address, core.PhysicalTransport.BR_EDR
|
||||
),
|
||||
self.devices[1].accept(self.devices[0].public_address),
|
||||
)
|
||||
class TwoDevices(test_utils.TwoDevices):
|
||||
protocols: Sequence[avrcp.Protocol] = ()
|
||||
|
||||
async def setup_avdtp_connections(self):
|
||||
self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
|
||||
self.protocols[0].listen(self.devices[1])
|
||||
await self.protocols[1].connect(self.connections[0])
|
||||
|
||||
@classmethod
|
||||
async def create_with_avdtp(cls) -> TwoDevices:
|
||||
devices = await cls.create_with_connection()
|
||||
await devices.setup_avdtp_connections()
|
||||
return devices
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"command,",
|
||||
[
|
||||
avrcp.GetPlayStatusCommand(),
|
||||
avrcp.GetCapabilitiesCommand(
|
||||
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID
|
||||
),
|
||||
avrcp.SetAbsoluteVolumeCommand(volume=5),
|
||||
avrcp.GetElementAttributesCommand(
|
||||
identifier=999,
|
||||
attribute_ids=[
|
||||
avrcp.MediaAttributeId.ALBUM_NAME,
|
||||
avrcp.MediaAttributeId.ARTIST_NAME,
|
||||
],
|
||||
),
|
||||
avrcp.RegisterNotificationCommand(
|
||||
event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123
|
||||
),
|
||||
avrcp.SearchCommand(
|
||||
character_set_id=avrcp.CharacterSetId.UTF_8, search_string="Bumble!"
|
||||
),
|
||||
avrcp.PlayItemCommand(
|
||||
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
|
||||
),
|
||||
avrcp.ListPlayerApplicationSettingAttributesCommand(),
|
||||
avrcp.ListPlayerApplicationSettingValuesCommand(
|
||||
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
|
||||
),
|
||||
avrcp.GetCurrentPlayerApplicationSettingValueCommand(
|
||||
attribute=[
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||
]
|
||||
),
|
||||
avrcp.SetPlayerApplicationSettingValueCommand(
|
||||
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
|
||||
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
|
||||
),
|
||||
avrcp.GetPlayerApplicationSettingAttributeTextCommand(
|
||||
attribute=[
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||
]
|
||||
),
|
||||
avrcp.GetPlayerApplicationSettingValueTextCommand(
|
||||
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
value=[
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||
],
|
||||
),
|
||||
avrcp.InformDisplayableCharacterSetCommand(
|
||||
character_set_id=[avrcp.CharacterSetId.UTF_8]
|
||||
),
|
||||
avrcp.InformBatteryStatusOfCtCommand(
|
||||
battery_status=avrcp.InformBatteryStatusOfCtCommand.BatteryStatus.NORMAL
|
||||
),
|
||||
avrcp.SetAddressedPlayerCommand(player_id=1),
|
||||
avrcp.SetBrowsedPlayerCommand(player_id=1),
|
||||
avrcp.GetFolderItemsCommand(
|
||||
scope=avrcp.Scope.NOW_PLAYING,
|
||||
start_item=0,
|
||||
end_item=1,
|
||||
attributes=[avrcp.MediaAttributeId.ARTIST_NAME],
|
||||
),
|
||||
avrcp.ChangePathCommand(
|
||||
uid_counter=1,
|
||||
direction=avrcp.ChangePathCommand.Direction.DOWN,
|
||||
folder_uid=2,
|
||||
),
|
||||
avrcp.GetItemAttributesCommand(
|
||||
scope=avrcp.Scope.NOW_PLAYING,
|
||||
uid=0,
|
||||
uid_counter=1,
|
||||
start_item=0,
|
||||
end_item=0,
|
||||
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
|
||||
),
|
||||
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
|
||||
avrcp.AddToNowPlayingCommand(
|
||||
scope=avrcp.Scope.NOW_PLAYING, uid=0, uid_counter=1
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_command(command: avrcp.Command):
|
||||
assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"event,",
|
||||
[
|
||||
avrcp.UidsChangedEvent(uid_counter=7),
|
||||
avrcp.TrackChangedEvent(identifier=b'12356'),
|
||||
avrcp.VolumeChangedEvent(volume=9),
|
||||
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
|
||||
avrcp.AddressedPlayerChangedEvent(
|
||||
player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10)
|
||||
),
|
||||
avrcp.AvailablePlayersChangedEvent(),
|
||||
avrcp.PlaybackPositionChangedEvent(playback_position=1314),
|
||||
avrcp.NowPlayingContentChangedEvent(),
|
||||
avrcp.PlayerApplicationSettingChangedEvent(
|
||||
player_application_settings=[
|
||||
avrcp.PlayerApplicationSettingChangedEvent.Setting(
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_event(event: avrcp.Event):
|
||||
assert avrcp.Event.from_bytes(bytes(event)) == event
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response,",
|
||||
[
|
||||
avrcp.GetPlayStatusResponse(
|
||||
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
|
||||
),
|
||||
avrcp.GetCapabilitiesResponse(
|
||||
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED,
|
||||
capabilities=[
|
||||
avrcp.EventId.ADDRESSED_PLAYER_CHANGED,
|
||||
avrcp.EventId.BATT_STATUS_CHANGED,
|
||||
],
|
||||
),
|
||||
avrcp.RegisterNotificationResponse(
|
||||
event=avrcp.PlaybackPositionChangedEvent(playback_position=38)
|
||||
),
|
||||
avrcp.SetAbsoluteVolumeResponse(volume=99),
|
||||
avrcp.GetElementAttributesResponse(
|
||||
attributes=[
|
||||
avrcp.MediaAttribute(
|
||||
attribute_id=avrcp.MediaAttributeId.ALBUM_NAME,
|
||||
attribute_value="White Album",
|
||||
character_set_id=avrcp.CharacterSetId.UTF_8,
|
||||
)
|
||||
]
|
||||
),
|
||||
avrcp.ListPlayerApplicationSettingAttributesResponse(
|
||||
attribute=[
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||
]
|
||||
),
|
||||
avrcp.ListPlayerApplicationSettingValuesResponse(
|
||||
value=[
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||
]
|
||||
),
|
||||
avrcp.GetCurrentPlayerApplicationSettingValueResponse(
|
||||
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
|
||||
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
|
||||
),
|
||||
avrcp.SetPlayerApplicationSettingValueResponse(),
|
||||
avrcp.GetPlayerApplicationSettingAttributeTextResponse(
|
||||
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
|
||||
character_set_id=[avrcp.CharacterSetId.UTF_8],
|
||||
attribute_string=["Repeat"],
|
||||
),
|
||||
avrcp.GetPlayerApplicationSettingValueTextResponse(
|
||||
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
|
||||
character_set_id=[avrcp.CharacterSetId.UTF_8],
|
||||
attribute_string=["All track repeat"],
|
||||
),
|
||||
avrcp.InformDisplayableCharacterSetResponse(),
|
||||
avrcp.InformBatteryStatusOfCtResponse(),
|
||||
avrcp.SetAddressedPlayerResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
|
||||
avrcp.SetBrowsedPlayerResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED,
|
||||
uid_counter=1,
|
||||
numbers_of_items=2,
|
||||
character_set_id=avrcp.CharacterSetId.UTF_8,
|
||||
folder_names=["folder1", "folder2"],
|
||||
),
|
||||
avrcp.GetFolderItemsResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED,
|
||||
uid_counter=1,
|
||||
items=[
|
||||
avrcp.MediaPlayerItem(
|
||||
player_id=1,
|
||||
major_player_type=avrcp.MediaPlayerItem.MajorPlayerType.AUDIO,
|
||||
player_sub_type=avrcp.MediaPlayerItem.PlayerSubType.AUDIO_BOOK,
|
||||
play_status=avrcp.PlayStatus.FWD_SEEK,
|
||||
feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING,
|
||||
character_set_id=avrcp.CharacterSetId.UTF_8,
|
||||
displayable_name="Woo",
|
||||
)
|
||||
],
|
||||
),
|
||||
avrcp.ChangePathResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED, number_of_items=2
|
||||
),
|
||||
avrcp.GetItemAttributesResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED,
|
||||
attribute_value_entry_list=[
|
||||
avrcp.AttributeValueEntry(
|
||||
attribute_id=avrcp.MediaAttributeId.GENRE,
|
||||
character_set_id=avrcp.CharacterSetId.UTF_8,
|
||||
attribute_value="uuddlrlrabab",
|
||||
)
|
||||
],
|
||||
),
|
||||
avrcp.GetTotalNumberOfItemsResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED,
|
||||
uid_counter=1,
|
||||
number_of_items=2,
|
||||
),
|
||||
avrcp.SearchResponse(
|
||||
status=avrcp.StatusCode.OPERATION_COMPLETED,
|
||||
uid_counter=1,
|
||||
number_of_items=2,
|
||||
),
|
||||
avrcp.PlayItemResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
|
||||
avrcp.AddToNowPlayingResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
|
||||
],
|
||||
)
|
||||
def test_response(response: avrcp.Response):
|
||||
assert avrcp.Response.from_bytes(bytes(response), response.pdu_id) == response
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_frame_parser():
|
||||
with pytest.raises(ValueError) as error:
|
||||
with pytest.raises(ValueError):
|
||||
avc.Frame.from_bytes(bytes.fromhex("11480000"))
|
||||
|
||||
x = bytes.fromhex("014D0208")
|
||||
@@ -217,8 +411,7 @@ def test_passthrough_commands():
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_supported_events():
|
||||
two_devices = TwoDevices()
|
||||
await two_devices.setup_connections()
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
supported_events = await two_devices.protocols[0].get_supported_events()
|
||||
assert supported_events == []
|
||||
|
||||
Reference in New Issue
Block a user