Merge pull request #764 from khsiao-google/update

Add typing for device.py
This commit is contained in:
khsiao-google
2025-09-01 15:19:57 +08:00
committed by GitHub

View File

@@ -48,6 +48,7 @@ from typing_extensions import Self
from bumble import (
core,
data_types,
gatt,
gatt_client,
gatt_server,
hci,
@@ -264,7 +265,7 @@ class ExtendedAdvertisement(Advertisement):
# -----------------------------------------------------------------------------
class AdvertisementDataAccumulator:
def __init__(self, passive=False):
def __init__(self, passive: bool = False):
self.passive = passive
self.last_advertisement = None
self.last_data = b''
@@ -1243,7 +1244,7 @@ class LePhyOptions:
PREFER_S_2_CODED_PHY = 1
PREFER_S_8_CODED_PHY = 2
def __init__(self, coded_phy_preference=0):
def __init__(self, coded_phy_preference: int = 0):
self.coded_phy_preference = coded_phy_preference
def __int__(self):
@@ -1691,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter):
self_address: hci.Address
self_resolvable_address: Optional[hci.Address]
peer_address: hci.Address
peer_name: Optional[str]
peer_resolvable_address: Optional[hci.Address]
peer_le_features: Optional[hci.LeFeatureMask]
role: hci.Role
@@ -1931,7 +1933,7 @@ class Connection(utils.CompositeEventEmitter):
self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result)
self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None:
async def set_data_length(self, tx_octets: int, tx_time: int) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)
async def update_parameters(
@@ -1961,7 +1963,12 @@ class Connection(utils.CompositeEventEmitter):
use_l2cap=use_l2cap,
)
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
async def set_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options)
async def get_phy(self) -> ConnectionPHY:
@@ -2156,7 +2163,7 @@ class DeviceConfiguration:
# Decorator that converts the first argument from a connection handle to a connection
def with_connection_from_handle(function):
@functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs):
def wrapper(self, connection_handle: int, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None:
raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
@@ -2169,7 +2176,7 @@ def with_connection_from_handle(function):
# Decorator that converts the first argument from a bluetooth address to a connection
def with_connection_from_address(function):
@functools.wraps(function)
def wrapper(self, address, *args, **kwargs):
def wrapper(self, address: hci.Address, *args, **kwargs):
if connection := self.pending_connections.get(address, False):
return function(self, connection, *args, **kwargs)
for connection in self.connections.values():
@@ -2655,7 +2662,7 @@ class Device(utils.CompositeEventEmitter):
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
async def send_command(self, command, check_result=False):
async def send_command(self, command: hci.HCI_Command, check_result: bool = False):
try:
return await asyncio.wait_for(
self.host.send_command(command, check_result), self.command_timeout
@@ -2912,13 +2919,13 @@ class Device(utils.CompositeEventEmitter):
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
return self.host.supports_le_features(feature)
def supports_le_phy(self, phy: int) -> bool:
if phy == hci.HCI_LE_1M_PHY:
def supports_le_phy(self, phy: hci.Phy) -> bool:
if phy == hci.Phy.LE_1M:
return True
feature_map: dict[int, hci.LeFeatureMask] = {
hci.HCI_LE_2M_PHY: hci.LeFeatureMask.LE_2M_PHY,
hci.HCI_LE_CODED_PHY: hci.LeFeatureMask.LE_CODED_PHY,
feature_map: dict[hci.Phy, hci.LeFeatureMask] = {
hci.Phy.LE_2M: hci.LeFeatureMask.LE_2M_PHY,
hci.Phy.LE_CODED: hci.LeFeatureMask.LE_CODED_PHY,
}
if phy not in feature_map:
raise InvalidArgumentError('invalid PHY')
@@ -3522,7 +3529,9 @@ class Device(utils.CompositeEventEmitter):
self.discovering = False
@host_event_handler
def on_inquiry_result(self, address, class_of_device, data, rssi):
def on_inquiry_result(
self, address: hci.Address, class_of_device: int, data: bytes, rssi: int
):
self.emit(
self.EVENT_INQUIRY_RESULT,
address,
@@ -3531,7 +3540,9 @@ class Device(utils.CompositeEventEmitter):
rssi,
)
async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled):
async def set_scan_enable(
self, inquiry_scan_enabled: bool, page_scan_enabled: bool
):
if inquiry_scan_enabled and page_scan_enabled:
scan_enable = 0x03
elif page_scan_enabled:
@@ -3657,6 +3668,7 @@ class Device(utils.CompositeEventEmitter):
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, transport
) # TODO: timeout
@@ -3684,7 +3696,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if transport == PhysicalTransport.LE or (
# match BR/EDR connection failure event against peer address
error.transport == transport
@@ -3904,6 +3916,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -3962,7 +3975,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if (
error.transport == PhysicalTransport.BR_EDR
and error.peer_address == peer_address
@@ -3999,7 +4012,7 @@ class Device(utils.CompositeEventEmitter):
self.pending_connections.pop(peer_address, None)
@asynccontextmanager
async def connect_as_gatt(self, peer_address):
async def connect_as_gatt(self, peer_address: Union[hci.Address, str]):
async with AsyncExitStack() as stack:
connection = await stack.enter_async_context(
await self.connect(peer_address)
@@ -4035,6 +4048,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -4080,7 +4094,9 @@ class Device(utils.CompositeEventEmitter):
)
self.disconnecting = False
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
async def set_data_length(
self, connection: Connection, tx_octets: int, tx_time: int
) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB:
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
@@ -4183,7 +4199,11 @@ class Device(utils.CompositeEventEmitter):
)
async def set_connection_phy(
self, connection, tx_phys=None, rx_phys=None, phy_options=None
self,
connection: Connection,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
if not self.host.supports_command(hci.HCI_LE_SET_PHY_COMMAND):
logger.warning('ignoring request, command not supported')
@@ -4199,7 +4219,7 @@ class Device(utils.CompositeEventEmitter):
all_phys=all_phys_bits,
tx_phys=hci.phy_list_to_bits(tx_phys),
rx_phys=hci.phy_list_to_bits(rx_phys),
phy_options=0 if phy_options is None else int(phy_options),
phy_options=phy_options,
)
)
@@ -4210,7 +4230,11 @@ class Device(utils.CompositeEventEmitter):
)
raise hci.HCI_StatusError(result)
async def set_default_phy(self, tx_phys=None, rx_phys=None):
async def set_default_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
):
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
)
@@ -4248,7 +4272,7 @@ class Device(utils.CompositeEventEmitter):
check_result=True,
)
async def find_peer_by_name(self, name, transport=PhysicalTransport.LE):
async def find_peer_by_name(self, name: str, transport=PhysicalTransport.LE):
"""
Scan for a peer with a given name and return its address.
"""
@@ -4263,7 +4287,7 @@ class Device(utils.CompositeEventEmitter):
if local_name == name:
peer_address.set_result(address)
listener = None
listener: Optional[Callable[..., None]] = None
was_scanning = self.scanning
was_discovering = self.discovering
try:
@@ -4369,10 +4393,10 @@ class Device(utils.CompositeEventEmitter):
def smp_session_proxy(self, session_proxy: type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy
async def pair(self, connection):
async def pair(self, connection: Connection):
return await self.smp_manager.pair(connection)
def request_pairing(self, connection):
def request_pairing(self, connection: Connection):
return self.smp_manager.request_pairing(connection)
async def get_long_term_key(
@@ -4460,7 +4484,7 @@ class Device(utils.CompositeEventEmitter):
on_authentication_failure,
)
async def encrypt(self, connection, enable=True):
async def encrypt(self, connection: Connection, enable: bool = True):
if not enable and connection.transport == PhysicalTransport.LE:
raise InvalidArgumentError('`enable` parameter is classic only.')
@@ -4470,7 +4494,7 @@ class Device(utils.CompositeEventEmitter):
def on_encryption_change():
pending_encryption.set_result(None)
def on_encryption_failure(error_code):
def on_encryption_failure(error_code: int):
pending_encryption.set_exception(hci.HCI_Error(error_code))
connection.on(
@@ -4562,10 +4586,10 @@ class Device(utils.CompositeEventEmitter):
async def switch_role(self, connection: Connection, role: hci.Role):
pending_role_change = asyncio.get_running_loop().create_future()
def on_role_change(new_role):
def on_role_change(new_role: hci.Role):
pending_role_change.set_result(new_role)
def on_role_change_failure(error_code):
def on_role_change_failure(error_code: int):
pending_role_change.set_exception(hci.HCI_Error(error_code))
connection.on(connection.EVENT_ROLE_CHANGE, on_role_change)
@@ -5155,10 +5179,10 @@ class Device(utils.CompositeEventEmitter):
):
connection.emit(connection.EVENT_LINK_KEY)
def add_service(self, service):
def add_service(self, service: gatt.Service):
self.gatt_server.add_service(service)
def add_services(self, services):
def add_services(self, services: Iterable[gatt.Service]):
self.gatt_server.add_services(services)
def add_default_services(
@@ -5254,10 +5278,10 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
def on_advertising_set_termination(
self,
status,
advertising_handle,
connection_handle,
number_of_completed_extended_advertising_events,
status: int,
advertising_handle: int,
connection_handle: int,
number_of_completed_extended_advertising_events: int,
):
# Legacy advertising set is also one of extended advertising sets.
if not (
@@ -5556,7 +5580,12 @@ class Device(utils.CompositeEventEmitter):
)
@host_event_handler
def on_connection_failure(self, transport, peer_address, error_code):
def on_connection_failure(
self,
transport: hci.PhysicalTransport,
peer_address: hci.Address,
error_code: int,
):
logger.debug(
f'*** Connection failed: {hci.HCI_Constant.error_name(error_code)}'
)
@@ -5675,7 +5704,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication(self, connection):
def on_connection_authentication(self, connection: Connection):
logger.debug(
f'*** Connection Authentication: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -5685,7 +5714,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication_failure(self, connection, error):
def on_connection_authentication_failure(
self, connection: Connection, error: core.ConnectionError
):
logger.debug(
f'*** Connection Authentication Failure: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, error={error}'
@@ -5727,10 +5758,13 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_address
def on_authentication_io_capability_response(
self, connection, io_capability, authentication_requirements
self,
connection: Connection,
io_capability: int,
authentication_requirements: int,
):
connection.peer_pairing_io_capability = io_capability
connection.peer_pairing_authentication_requirements = (
connection.pairing_peer_io_capability = io_capability
connection.pairing_peer_authentication_requirements = (
authentication_requirements
)
@@ -5741,7 +5775,7 @@ class Device(utils.CompositeEventEmitter):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
peer_io_capability = connection.peer_pairing_io_capability
peer_io_capability = connection.pairing_peer_io_capability
async def confirm() -> bool:
# Ask the user to confirm the pairing, without display
@@ -5816,7 +5850,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_request(self, connection) -> None:
def on_authentication_user_passkey_request(self, connection: Connection) -> None:
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5859,7 +5893,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_pin_code_request(self, connection):
def on_pin_code_request(self, connection: Connection):
# Classic legacy pairing
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5903,7 +5937,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_notification(self, connection, passkey):
def on_authentication_user_passkey_notification(
self, connection: Connection, passkey: int
):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5915,14 +5951,15 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name(self, connection: Connection, address, remote_name):
def on_remote_name(
self, connection: Connection, address: hci.Address, remote_name: bytes
):
# Try to decode the name
try:
remote_name = remote_name.decode('utf-8')
if connection:
connection.peer_name = remote_name
connection.peer_name = remote_name.decode('utf-8')
connection.emit(connection.EVENT_REMOTE_NAME)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name.decode('utf-8'))
except UnicodeDecodeError as error:
logger.warning('peer name is not valid UTF-8')
if connection:
@@ -5933,7 +5970,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name_failure(self, connection: Connection, address, error):
def on_remote_name_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
@@ -6134,7 +6173,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_encryption_key_refresh(self, connection):
def on_connection_encryption_key_refresh(self, connection: Connection):
logger.debug(
f'*** Connection Key Refresh: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -6172,7 +6211,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_parameters_update_failure(self, connection, error):
def on_connection_parameters_update_failure(
self, connection: Connection, error: int
):
logger.debug(
f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6182,7 +6223,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update(self, connection, phy):
def on_connection_phy_update(self, connection: Connection, phy: core.ConnectionPHY):
logger.debug(
f'*** Connection PHY Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6192,7 +6233,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update_failure(self, connection, error):
def on_connection_phy_update_failure(self, connection: Connection, error: int):
logger.debug(
f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6221,7 +6262,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection, att_mtu):
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6233,7 +6274,12 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_data_length_change(
self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time
self,
connection: Connection,
max_tx_octets: int,
max_tx_time: int,
max_rx_octets: int,
max_rx_time: int,
):
logger.debug(
f'*** Connection Data Length Change: [0x{connection.handle:04X}] '
@@ -6358,14 +6404,16 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_role_change(self, connection, new_role):
def on_role_change(self, connection: Connection, new_role: hci.Role):
connection.role = new_role
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_role_change_failure(self, connection, address, error):
def on_role_change_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error)
self.emit(self.EVENT_ROLE_CHANGE_FAILURE, address, error)
@@ -6379,7 +6427,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status) -> None:
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
def on_pairing_start(self, connection: Connection) -> None:
@@ -6403,7 +6451,7 @@ class Device(utils.CompositeEventEmitter):
connection.emit(connection.EVENT_PAIRING_FAILURE, reason)
@with_connection_from_handle
def on_gatt_pdu(self, connection, pdu):
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu)
@@ -6425,7 +6473,7 @@ class Device(utils.CompositeEventEmitter):
connection.gatt_server.on_gatt_pdu(connection, att_pdu)
@with_connection_from_handle
def on_smp_pdu(self, connection, pdu):
def on_smp_pdu(self, connection: Connection, pdu: bytes):
self.smp_manager.on_smp_pdu(connection, pdu)
@host_event_handler