overall: add types hints to the small subset used by avatar

This commit is contained in:
uael
2023-02-02 17:36:23 +00:00
parent ed261886e1
commit b731f6f556
13 changed files with 234 additions and 166 deletions

View File

@@ -23,7 +23,7 @@ import asyncio
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from typing import ClassVar
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from colors import color
@@ -197,6 +197,8 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
# -----------------------------------------------------------------------------
class Advertisement:
address: Address
TX_POWER_NOT_AVAILABLE = (
HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
)
@@ -511,6 +513,17 @@ ConnectionParametersPreferences.default = ConnectionParametersPreferences()
# -----------------------------------------------------------------------------
class Connection(CompositeEventEmitter):
device: Device
handle: int
transport: int
self_address: Address
peer_address: Address
role: int
encryption: int
authenticated: bool
sc: bool
link_key_type: int
@composite_listener
class Listener:
def on_disconnection(self, reason):
@@ -611,6 +624,10 @@ class Connection(CompositeEventEmitter):
def is_encrypted(self):
return self.encryption != 0
@property
def is_incomplete(self) -> bool:
return self.handle == None
def send_l2cap_pdu(self, cid, pdu):
self.device.send_l2cap_pdu(self.handle, cid, pdu)
@@ -626,20 +643,22 @@ class Connection(CompositeEventEmitter):
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
return await self.device.disconnect(self, reason)
async def disconnect(
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
async def pair(self):
async def pair(self) -> None:
return await self.device.pair(self)
def request_pairing(self):
def request_pairing(self) -> None:
return self.device.request_pairing(self)
# [Classic only]
async def authenticate(self):
async def authenticate(self) -> None:
return await self.device.authenticate(self)
async def encrypt(self, enable=True):
async def encrypt(self, enable: bool = True) -> None:
return await self.device.encrypt(self, enable)
async def sustain(self, timeout=None):
@@ -707,10 +726,10 @@ class Connection(CompositeEventEmitter):
# -----------------------------------------------------------------------------
class DeviceConfiguration:
def __init__(self):
def __init__(self) -> None:
# Setup defaults
self.name = DEVICE_DEFAULT_NAME
self.address = DEVICE_DEFAULT_ADDRESS
self.address = Address(DEVICE_DEFAULT_ADDRESS)
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
@@ -730,12 +749,13 @@ class DeviceConfiguration:
)
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
self.gatt_services = []
self.gatt_services: List[Dict[str, Any]] = []
def load_from_dict(self, config):
def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
if address := config.get('address', None):
self.address = Address(address)
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get(
'advertising_interval', self.advertising_interval_min
@@ -842,6 +862,22 @@ device_host_event_handlers: list[str] = []
# -----------------------------------------------------------------------------
class Device(CompositeEventEmitter):
# incomplete list of fields.
random_address: Address
public_address: Address
classic_enabled: bool
name: str
class_of_device: int
gatt_server: gatt_server.Server
advertising_data: bytes
scan_response_data: bytes
connections: Dict[int, Connection]
pending_connections: Dict[Address, Connection]
classic_pending_accepts: Dict[
Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]]
]
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
@composite_listener
class Listener:
def on_advertisement(self, advertisement):
@@ -888,12 +924,12 @@ class Device(CompositeEventEmitter):
def __init__(
self,
name=None,
address=None,
config=None,
host=None,
generic_access_service=True,
):
name: Optional[str] = None,
address: Optional[Address] = None,
config: Optional[DeviceConfiguration] = None,
host: Optional[Host] = None,
generic_access_service: bool = True,
) -> None:
super().__init__()
self._host = None
@@ -995,10 +1031,12 @@ class Device(CompositeEventEmitter):
setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription')
# Set the initial host
self.host = host
if host:
self.host = host
@property
def host(self):
def host(self) -> Host:
assert self._host
return self._host
@host.setter
@@ -1032,15 +1070,18 @@ class Device(CompositeEventEmitter):
def sdp_service_records(self, service_records):
self.sdp_server.service_records = service_records
def lookup_connection(self, connection_handle):
def lookup_connection(self, connection_handle: int) -> Optional[Connection]:
if connection := self.connections.get(connection_handle):
return connection
return None
def find_connection_by_bd_addr(
self, bd_addr, transport=None, check_address_type=False
):
self,
bd_addr: Address,
transport: Optional[int] = None,
check_address_type: bool = False,
) -> Optional[Connection]:
for connection in self.connections.values():
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
if (
@@ -1098,11 +1139,11 @@ class Device(CompositeEventEmitter):
logger.warning('!!! Command timed out')
raise CommandTimeoutError() from error
async def power_on(self):
async def power_on(self) -> None:
# Reset the controller
await self.host.reset()
response = await self.send_command(HCI_Read_BD_ADDR_Command())
response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg]
if response.return_parameters.status == HCI_SUCCESS:
logger.debug(
color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow')
@@ -1114,7 +1155,7 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
)
) # type: ignore[call-arg]
)
if self.le_enabled:
@@ -1124,7 +1165,7 @@ class Device(CompositeEventEmitter):
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg]
)
# Ensure the address bytes can be a static random address
@@ -1145,7 +1186,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Random_Address_Command(
random_address=self.random_address
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1153,7 +1194,7 @@ class Device(CompositeEventEmitter):
if self.keystore and self.host.supports_command(
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
):
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
resolving_keys = await self.keystore.get_resolving_keys()
for (irk, address) in resolving_keys:
@@ -1163,7 +1204,7 @@ class Device(CompositeEventEmitter):
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
)
) # type: ignore[call-arg]
)
# Enable address resolution
@@ -1178,28 +1219,24 @@ class Device(CompositeEventEmitter):
if self.classic_enabled:
await self.send_command(
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8'))
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device)
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Simple_Pairing_Mode_Command(
simple_pairing_mode=int(self.classic_ssp_enabled)
)
) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Secure_Connections_Host_Support_Command(
secure_connections_host_support=int(self.classic_sc_enabled)
)
) # type: ignore[call-arg]
)
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
# Let the SMP manager know about the address
# TODO: allow using a public address
self.smp_manager.address = self.random_address
# Done
self.powered_on = True
@@ -1221,11 +1258,11 @@ class Device(CompositeEventEmitter):
async def start_advertising(
self,
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target=None,
own_address_type=OwnAddressType.RANDOM,
auto_restart=False,
):
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target: Optional[Address] = None,
own_address_type: int = OwnAddressType.RANDOM,
auto_restart: bool = False,
) -> None:
# If we're advertising, stop first
if self.advertising:
await self.stop_advertising()
@@ -1235,7 +1272,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Advertising_Data_Command(
advertising_data=self.advertising_data
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1244,7 +1281,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Scan_Response_Data_Command(
scan_response_data=self.scan_response_data
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1270,13 +1307,13 @@ class Device(CompositeEventEmitter):
peer_address=peer_address,
advertising_channel_map=7,
advertising_filter_policy=0,
),
), # type: ignore[call-arg]
check_result=True,
)
# Enable advertising
await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1),
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg]
check_result=True,
)
@@ -1285,11 +1322,11 @@ class Device(CompositeEventEmitter):
self.advertising_type = advertising_type
self.advertising = True
async def stop_advertising(self):
async def stop_advertising(self) -> None:
# Disable advertising
if self.advertising:
await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg]
check_result=True,
)
@@ -1304,14 +1341,14 @@ class Device(CompositeEventEmitter):
async def start_scanning(
self,
legacy=False,
active=True,
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type=OwnAddressType.RANDOM,
filter_duplicates=False,
scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
):
legacy: bool = False,
active: bool = True,
scan_interval: int = DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type: int = OwnAddressType.RANDOM,
filter_duplicates: bool = False,
scanning_phys: Tuple[int, int] = (HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
) -> None:
# Check that the arguments are legal
if scan_interval < scan_window:
raise ValueError('scan_interval must be >= scan_window')
@@ -1361,7 +1398,7 @@ class Device(CompositeEventEmitter):
scan_types=[scan_type] * scanning_phy_count,
scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count,
scan_windows=[int(scan_window / 0.625)] * scanning_phy_count,
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1372,7 +1409,7 @@ class Device(CompositeEventEmitter):
filter_duplicates=1 if filter_duplicates else 0,
duration=0, # TODO allow other values
period=0, # TODO allow other values
),
), # type: ignore[call-arg]
check_result=True,
)
else:
@@ -1390,7 +1427,7 @@ class Device(CompositeEventEmitter):
le_scan_window=int(scan_window / 0.625),
own_address_type=own_address_type,
scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY,
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1398,25 +1435,25 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Scan_Enable_Command(
le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0
),
), # type: ignore[call-arg]
check_result=True,
)
self.scanning_is_passive = not active
self.scanning = True
async def stop_scanning(self):
async def stop_scanning(self) -> None:
# Disable scanning
if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
await self.send_command(
HCI_LE_Set_Extended_Scan_Enable_Command(
enable=0, filter_duplicates=0, duration=0, period=0
),
), # type: ignore[call-arg]
check_result=True,
)
else:
await self.send_command(
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0),
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg]
check_result=True,
)
@@ -1434,9 +1471,9 @@ class Device(CompositeEventEmitter):
if advertisement := accumulator.update(report):
self.emit('advertisement', advertisement)
async def start_discovery(self, auto_restart=True):
async def start_discovery(self, auto_restart: bool = True) -> None:
await self.send_command(
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE),
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg]
check_result=True,
)
@@ -1445,7 +1482,7 @@ class Device(CompositeEventEmitter):
lap=HCI_GENERAL_INQUIRY_LAP,
inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH,
num_responses=0, # Unlimited number of responses.
)
) # type: ignore[call-arg]
)
if response.status != HCI_Command_Status_Event.PENDING:
self.discovering = False
@@ -1454,9 +1491,9 @@ class Device(CompositeEventEmitter):
self.auto_restart_inquiry = auto_restart
self.discovering = True
async def stop_discovery(self):
async def stop_discovery(self) -> None:
if self.discovering:
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg]
self.auto_restart_inquiry = True
self.discovering = False
@@ -1484,7 +1521,7 @@ class Device(CompositeEventEmitter):
HCI_Write_Scan_Enable_Command(scan_enable=scan_enable)
)
async def set_discoverable(self, discoverable=True):
async def set_discoverable(self, discoverable: bool = True) -> None:
self.discoverable = discoverable
if self.classic_enabled:
# Synthesize an inquiry response if none is set already
@@ -1504,7 +1541,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_Write_Extended_Inquiry_Response_Command(
fec_required=0, extended_inquiry_response=self.inquiry_response
),
), # type: ignore[call-arg]
check_result=True,
)
await self.set_scan_enable(
@@ -1512,7 +1549,7 @@ class Device(CompositeEventEmitter):
page_scan_enabled=self.connectable,
)
async def set_connectable(self, connectable=True):
async def set_connectable(self, connectable: bool = True) -> None:
self.connectable = connectable
if self.classic_enabled:
await self.set_scan_enable(
@@ -1522,12 +1559,14 @@ class Device(CompositeEventEmitter):
async def connect(
self,
peer_address,
transport=BT_LE_TRANSPORT,
connection_parameters_preferences=None,
own_address_type=OwnAddressType.RANDOM,
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
):
peer_address: Union[Address, str],
transport: int = BT_LE_TRANSPORT,
connection_parameters_preferences: Optional[
Dict[int, ConnectionParametersPreferences]
] = None,
own_address_type: int = OwnAddressType.RANDOM,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
'''
Request a connection to a peer.
When transport is BLE, this method cannot be called if there is already a
@@ -1574,6 +1613,8 @@ class Device(CompositeEventEmitter):
):
raise ValueError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, Address)
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
@@ -1691,7 +1732,7 @@ class Device(CompositeEventEmitter):
supervision_timeouts=supervision_timeouts,
min_ce_lengths=min_ce_lengths,
max_ce_lengths=max_ce_lengths,
)
) # type: ignore[call-arg]
)
else:
if HCI_LE_1M_PHY not in connection_parameters_preferences:
@@ -1720,7 +1761,7 @@ class Device(CompositeEventEmitter):
supervision_timeout=int(prefs.supervision_timeout / 10),
min_ce_length=int(prefs.min_ce_length / 0.625),
max_ce_length=int(prefs.max_ce_length / 0.625),
)
) # type: ignore[call-arg]
)
else:
# Save pending connection
@@ -1737,7 +1778,7 @@ class Device(CompositeEventEmitter):
clock_offset=0x0000,
allow_role_switch=0x01,
reserved=0,
)
) # type: ignore[call-arg]
)
if result.status != HCI_Command_Status_Event.PENDING:
@@ -1756,10 +1797,10 @@ class Device(CompositeEventEmitter):
)
except asyncio.TimeoutError:
if transport == BT_LE_TRANSPORT:
await self.send_command(HCI_LE_Create_Connection_Cancel_Command())
await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg]
else:
await self.send_command(
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg]
)
try:
@@ -1777,10 +1818,10 @@ class Device(CompositeEventEmitter):
async def accept(
self,
peer_address=Address.ANY,
role=BT_PERIPHERAL_ROLE,
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
):
peer_address: Union[Address, str] = Address.ANY,
role: int = BT_PERIPHERAL_ROLE,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
'''
Wait and accept any incoming connection or a connection from `peer_address` when
set.
@@ -1802,22 +1843,24 @@ class Device(CompositeEventEmitter):
peer_address, BT_BR_EDR_TRANSPORT
) # TODO: timeout
assert isinstance(peer_address, Address)
if peer_address == Address.NIL:
raise ValueError('accept on nil address')
# Create a future so that we can wait for the request
pending_request = asyncio.get_running_loop().create_future()
pending_request_fut = asyncio.get_running_loop().create_future()
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].append(pending_request)
self.classic_pending_accepts[Address.ANY].append(pending_request_fut)
elif peer_address in self.classic_pending_accepts:
raise InvalidStateError('accept connection already pending')
else:
self.classic_pending_accepts[peer_address] = pending_request
self.classic_pending_accepts[peer_address] = [pending_request_fut]
try:
# Wait for a request or a completed connection
pending_request = self.abort_on('flush', pending_request)
pending_request = self.abort_on('flush', pending_request_fut)
result = await (
asyncio.wait_for(pending_request, timeout)
if timeout
@@ -1826,7 +1869,7 @@ class Device(CompositeEventEmitter):
except Exception:
# Remove future from device context
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].remove(pending_request)
self.classic_pending_accepts[Address.ANY].remove(pending_request_fut)
else:
self.classic_pending_accepts.pop(peer_address)
raise
@@ -1838,6 +1881,7 @@ class Device(CompositeEventEmitter):
# Otherwise, result came from `on_connection_request`
peer_address, _class_of_device, _link_type = result
assert isinstance(peer_address, Address)
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
@@ -1867,7 +1911,7 @@ class Device(CompositeEventEmitter):
try:
# Accept connection request
await self.send_command(
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role)
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg]
)
# Wait for connection complete
@@ -2243,7 +2287,7 @@ class Device(CompositeEventEmitter):
)
# [Classic only]
async def request_remote_name(self, remote): # remote: Connection | Address
async def request_remote_name(self, remote: Union[Address, Connection]) -> str:
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
@@ -2271,7 +2315,7 @@ class Device(CompositeEventEmitter):
page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2,
reserved=0,
clock_offset=0, # TODO investigate non-0 values
)
) # type: ignore[call-arg]
)
if result.status != HCI_COMMAND_STATUS_PENDING:
@@ -2372,7 +2416,7 @@ class Device(CompositeEventEmitter):
# In this case, set the completed `connection` to the `accept` future
# result.
if peer_address in self.classic_pending_accepts:
future = self.classic_pending_accepts.pop(peer_address)
future, *_ = self.classic_pending_accepts.pop(peer_address)
future.set_result(connection)
# Emit an event to notify listeners of the new connection
@@ -2473,7 +2517,7 @@ class Device(CompositeEventEmitter):
# match a pending future using `bd_addr`
if bd_addr in self.classic_pending_accepts:
future = self.classic_pending_accepts.pop(bd_addr)
future, *_ = self.classic_pending_accepts.pop(bd_addr)
future.set_result((bd_addr, class_of_device, link_type))
# match first pending future for ANY address