forked from auracaster/bumble_mirror
typing surrport for HCI commands return parameters
This commit is contained in:
293
bumble/host.py
293
bumble/host.py
@@ -23,11 +23,15 @@ import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
|
||||
|
||||
from bumble import drivers, hci, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport
|
||||
from bumble.core import (
|
||||
ConnectionPHY,
|
||||
InvalidStateError,
|
||||
PhysicalTransport,
|
||||
)
|
||||
from bumble.l2cap import L2CAP_PDU
|
||||
from bumble.snoop import Snooper
|
||||
from bumble.transport.common import TransportLostError
|
||||
@@ -35,7 +39,6 @@ from bumble.transport.common import TransportLostError
|
||||
if TYPE_CHECKING:
|
||||
from bumble.transport.common import TransportSink, TransportSource
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -236,6 +239,9 @@ class IsoLink:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
|
||||
|
||||
|
||||
class Host(utils.EventEmitter):
|
||||
connections: dict[int, Connection]
|
||||
cis_links: dict[int, IsoLink]
|
||||
@@ -264,11 +270,13 @@ class Host(utils.EventEmitter):
|
||||
self.bis_links = {} # BIS links, by connection handle
|
||||
self.sco_links = {} # SCO links, by connection handle
|
||||
self.bigs = {} # BIG Handle to BIS Handles
|
||||
self.pending_command = None
|
||||
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
|
||||
self.pending_response: asyncio.Future[Any] | None = None
|
||||
self.number_of_supported_advertising_sets = 0
|
||||
self.maximum_advertising_data_length = 31
|
||||
self.local_version = None
|
||||
self.local_version: (
|
||||
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
|
||||
) = None
|
||||
self.local_supported_commands = 0
|
||||
self.local_le_features = 0
|
||||
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
|
||||
@@ -312,7 +320,7 @@ class Host(utils.EventEmitter):
|
||||
self.emit('flush')
|
||||
self.command_semaphore.release()
|
||||
|
||||
async def reset(self, driver_factory=drivers.get_driver_for_host):
|
||||
async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
|
||||
if self.ready:
|
||||
self.ready = False
|
||||
await self.flush()
|
||||
@@ -330,57 +338,53 @@ class Host(utils.EventEmitter):
|
||||
|
||||
# Send a reset command unless a driver has already done so.
|
||||
if reset_needed:
|
||||
await self.send_command(hci.HCI_Reset_Command(), check_result=True)
|
||||
await self.send_sync_command(hci.HCI_Reset_Command())
|
||||
self.ready = True
|
||||
|
||||
response = await self.send_command(
|
||||
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True
|
||||
response1 = await self.send_sync_command(
|
||||
hci.HCI_Read_Local_Supported_Commands_Command()
|
||||
)
|
||||
self.local_supported_commands = int.from_bytes(
|
||||
response.return_parameters.supported_commands, 'little'
|
||||
response1.supported_commands, 'little'
|
||||
)
|
||||
|
||||
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
|
||||
response2 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Local_Supported_Features_Command()
|
||||
)
|
||||
self.local_le_features = struct.unpack(
|
||||
'<Q', response.return_parameters.le_features
|
||||
)[0]
|
||||
self.local_le_features = struct.unpack('<Q', response2.le_features)[0]
|
||||
|
||||
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
self.local_version = await self.send_sync_command(
|
||||
hci.HCI_Read_Local_Version_Information_Command()
|
||||
)
|
||||
self.local_version = response.return_parameters
|
||||
|
||||
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
|
||||
max_page_number = 0
|
||||
page_number = 0
|
||||
lmp_features = 0
|
||||
while page_number <= max_page_number:
|
||||
response = await self.send_command(
|
||||
response4 = await self.send_sync_command(
|
||||
hci.HCI_Read_Local_Extended_Features_Command(
|
||||
page_number=page_number
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
)
|
||||
lmp_features |= int.from_bytes(
|
||||
response.return_parameters.extended_lmp_features, 'little'
|
||||
response4.extended_lmp_features, 'little'
|
||||
) << (64 * page_number)
|
||||
max_page_number = response.return_parameters.maximum_page_number
|
||||
max_page_number = response4.maximum_page_number
|
||||
page_number += 1
|
||||
self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
|
||||
|
||||
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True
|
||||
response5 = await self.send_sync_command(
|
||||
hci.HCI_Read_Local_Supported_Features_Command()
|
||||
)
|
||||
self.local_lmp_features = hci.LmpFeatureMask(
|
||||
int.from_bytes(response.return_parameters.lmp_features, 'little')
|
||||
int.from_bytes(response5.lmp_features, 'little')
|
||||
)
|
||||
|
||||
await self.send_command(
|
||||
await self.send_sync_command(
|
||||
hci.HCI_Set_Event_Mask_Command(
|
||||
event_mask=hci.HCI_Set_Event_Mask_Command.mask(
|
||||
[
|
||||
@@ -437,7 +441,7 @@ class Host(utils.EventEmitter):
|
||||
)
|
||||
)
|
||||
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
|
||||
await self.send_command(
|
||||
await self.send_sync_command(
|
||||
hci.HCI_Set_Event_Mask_Page_2_Command(
|
||||
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
|
||||
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
|
||||
@@ -499,20 +503,14 @@ class Host(utils.EventEmitter):
|
||||
]
|
||||
)
|
||||
|
||||
await self.send_command(
|
||||
await self.send_sync_command(
|
||||
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
|
||||
)
|
||||
|
||||
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
hc_acl_data_packet_length = (
|
||||
response.return_parameters.hc_acl_data_packet_length
|
||||
)
|
||||
hc_total_num_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_acl_data_packets
|
||||
)
|
||||
response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
|
||||
hc_acl_data_packet_length = response6.hc_acl_data_packet_length
|
||||
hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
|
||||
|
||||
logger.debug(
|
||||
'HCI ACL flow control: '
|
||||
@@ -531,19 +529,13 @@ class Host(utils.EventEmitter):
|
||||
iso_data_packet_length = 0
|
||||
total_num_iso_data_packets = 0
|
||||
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
|
||||
)
|
||||
le_acl_data_packet_length = (
|
||||
response.return_parameters.le_acl_data_packet_length
|
||||
)
|
||||
total_num_le_acl_data_packets = (
|
||||
response.return_parameters.total_num_le_acl_data_packets
|
||||
)
|
||||
iso_data_packet_length = response.return_parameters.iso_data_packet_length
|
||||
total_num_iso_data_packets = (
|
||||
response.return_parameters.total_num_iso_data_packets
|
||||
response7 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_V2_Command()
|
||||
)
|
||||
le_acl_data_packet_length = response7.le_acl_data_packet_length
|
||||
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
|
||||
iso_data_packet_length = response7.iso_data_packet_length
|
||||
total_num_iso_data_packets = response7.total_num_iso_data_packets
|
||||
|
||||
logger.debug(
|
||||
'HCI LE flow control: '
|
||||
@@ -553,15 +545,11 @@ class Host(utils.EventEmitter):
|
||||
f'total_num_iso_data_packets={total_num_iso_data_packets}'
|
||||
)
|
||||
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
le_acl_data_packet_length = (
|
||||
response.return_parameters.le_acl_data_packet_length
|
||||
)
|
||||
total_num_le_acl_data_packets = (
|
||||
response.return_parameters.total_num_le_acl_data_packets
|
||||
response8 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_Command()
|
||||
)
|
||||
le_acl_data_packet_length = response8.le_acl_data_packet_length
|
||||
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
|
||||
|
||||
logger.debug(
|
||||
'HCI LE ACL flow control: '
|
||||
@@ -592,16 +580,16 @@ class Host(utils.EventEmitter):
|
||||
) and self.supports_command(
|
||||
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
|
||||
):
|
||||
response = await self.send_command(
|
||||
response9 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
|
||||
)
|
||||
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
|
||||
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
|
||||
suggested_max_tx_octets = response9.suggested_max_tx_octets
|
||||
suggested_max_tx_time = response9.suggested_max_tx_time
|
||||
if (
|
||||
suggested_max_tx_octets != self.suggested_max_tx_octets
|
||||
or suggested_max_tx_time != self.suggested_max_tx_time
|
||||
):
|
||||
await self.send_command(
|
||||
await self.send_sync_command(
|
||||
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
|
||||
suggested_max_tx_octets=self.suggested_max_tx_octets,
|
||||
suggested_max_tx_time=self.suggested_max_tx_time,
|
||||
@@ -611,23 +599,21 @@ class Host(utils.EventEmitter):
|
||||
if self.supports_command(
|
||||
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
|
||||
):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(),
|
||||
check_result=True,
|
||||
response10 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
|
||||
)
|
||||
self.number_of_supported_advertising_sets = (
|
||||
response.return_parameters.num_supported_advertising_sets
|
||||
response10.num_supported_advertising_sets
|
||||
)
|
||||
|
||||
if self.supports_command(
|
||||
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
|
||||
):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(),
|
||||
check_result=True,
|
||||
response11 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
|
||||
)
|
||||
self.maximum_advertising_data_length = (
|
||||
response.return_parameters.max_advertising_data_length
|
||||
response11.max_advertising_data_length
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -654,9 +640,11 @@ class Host(utils.EventEmitter):
|
||||
if self.hci_sink:
|
||||
self.hci_sink.on_packet(bytes(packet))
|
||||
|
||||
async def send_command(
|
||||
self, command, check_result=False, response_timeout: int | None = None
|
||||
):
|
||||
async def _send_command(
|
||||
self,
|
||||
command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
|
||||
response_timeout: float | None = None,
|
||||
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
|
||||
# Wait until we can send (only one pending command at a time)
|
||||
async with self.command_semaphore:
|
||||
assert self.pending_command is None
|
||||
@@ -668,29 +656,9 @@ class Host(utils.EventEmitter):
|
||||
|
||||
try:
|
||||
self.send_hci_packet(command)
|
||||
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
|
||||
response = self.pending_response.result()
|
||||
|
||||
# Check the return parameters if required
|
||||
if check_result:
|
||||
if isinstance(response, hci.HCI_Command_Status_Event):
|
||||
status = response.status # type: ignore[attr-defined]
|
||||
elif isinstance(response.return_parameters, int):
|
||||
status = response.return_parameters
|
||||
elif isinstance(response.return_parameters, bytes):
|
||||
# return parameters first field is a one byte status code
|
||||
status = response.return_parameters[0]
|
||||
else:
|
||||
status = response.return_parameters.status
|
||||
|
||||
if status != hci.HCI_SUCCESS:
|
||||
logger.warning(
|
||||
f'{command.name} failed '
|
||||
f'({hci.HCI_Constant.error_name(status)})'
|
||||
)
|
||||
raise hci.HCI_Error(status)
|
||||
|
||||
return response
|
||||
return await asyncio.wait_for(
|
||||
self.pending_response, timeout=response_timeout
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception while sending command:", "red"))
|
||||
raise
|
||||
@@ -698,12 +666,107 @@ class Host(utils.EventEmitter):
|
||||
self.pending_command = None
|
||||
self.pending_response = None
|
||||
|
||||
# Use this method to send a command from a task
|
||||
def send_command_sync(self, command: hci.HCI_Command) -> None:
|
||||
async def send_command(command: hci.HCI_Command) -> None:
|
||||
await self.send_command(command)
|
||||
@overload
|
||||
async def send_command(
|
||||
self,
|
||||
command: hci.HCI_SyncCommand[_RP],
|
||||
check_result: bool = False,
|
||||
response_timeout: float | None = None,
|
||||
) -> hci.HCI_Command_Complete_Event[_RP]: ...
|
||||
|
||||
asyncio.create_task(send_command(command))
|
||||
@overload
|
||||
async def send_command(
|
||||
self,
|
||||
command: hci.HCI_AsyncCommand,
|
||||
check_result: bool = False,
|
||||
response_timeout: float | None = None,
|
||||
) -> hci.HCI_Command_Status_Event: ...
|
||||
|
||||
async def send_command(
|
||||
self,
|
||||
command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
|
||||
check_result: bool = False,
|
||||
response_timeout: float | None = None,
|
||||
) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
|
||||
response = await self._send_command(command, response_timeout)
|
||||
|
||||
# Check the return parameters if required
|
||||
if check_result:
|
||||
if isinstance(response, hci.HCI_Command_Status_Event):
|
||||
status = response.status # type: ignore[attr-defined]
|
||||
elif isinstance(response.return_parameters, int):
|
||||
status = response.return_parameters
|
||||
elif isinstance(response.return_parameters, bytes):
|
||||
# return parameters first field is a one byte status code
|
||||
status = response.return_parameters[0]
|
||||
elif isinstance(
|
||||
response.return_parameters, hci.HCI_GenericReturnParameters
|
||||
):
|
||||
# FIXME: temporary workaround
|
||||
# NO STATUS
|
||||
status = hci.HCI_SUCCESS
|
||||
else:
|
||||
status = response.return_parameters.status
|
||||
|
||||
if status != hci.HCI_SUCCESS:
|
||||
logger.warning(
|
||||
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
|
||||
)
|
||||
raise hci.HCI_Error(status)
|
||||
|
||||
return response
|
||||
|
||||
async def send_sync_command(
|
||||
self,
|
||||
command: hci.HCI_SyncCommand[_RP],
|
||||
check_status: bool = True,
|
||||
response_timeout: float | None = None,
|
||||
) -> _RP:
|
||||
response = await self._send_command(command, response_timeout)
|
||||
|
||||
# Check that the response is of the expected type
|
||||
assert isinstance(response, hci.HCI_Command_Complete_Event)
|
||||
return_parameters: _RP = response.return_parameters
|
||||
assert isinstance(return_parameters, command.return_parameters_class)
|
||||
|
||||
# Check the return parameters if required
|
||||
if check_status:
|
||||
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
|
||||
status = return_parameters.status
|
||||
if status != hci.HCI_SUCCESS:
|
||||
logger.warning(
|
||||
f'{command.name} failed '
|
||||
f'({hci.HCI_Constant.error_name(status)})'
|
||||
)
|
||||
raise hci.HCI_Error(status)
|
||||
|
||||
return return_parameters
|
||||
|
||||
async def send_async_command(
|
||||
self,
|
||||
command: hci.HCI_AsyncCommand,
|
||||
check_status: bool = True,
|
||||
response_timeout: float | None = None,
|
||||
) -> hci.HCI_ErrorCode:
|
||||
response = await self._send_command(command, response_timeout)
|
||||
|
||||
# Check that the response is of the expected type
|
||||
assert isinstance(response, hci.HCI_Command_Status_Event)
|
||||
|
||||
# Check the return parameters if required
|
||||
status = response.status
|
||||
if check_status:
|
||||
if status != hci.HCI_CommandStatus.PENDING:
|
||||
logger.warning(
|
||||
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
|
||||
)
|
||||
raise hci.HCI_Error(status)
|
||||
|
||||
return hci.HCI_ErrorCode(status)
|
||||
|
||||
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
|
||||
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
|
||||
utils.AsyncRunner.spawn(self.send_async_command(command))
|
||||
|
||||
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
@@ -1338,15 +1401,17 @@ class Host(utils.EventEmitter):
|
||||
|
||||
# For now, just accept everything
|
||||
# TODO: delegate the decision
|
||||
self.send_command_sync(
|
||||
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
|
||||
connection_handle=event.connection_handle,
|
||||
interval_min=event.interval_min,
|
||||
interval_max=event.interval_max,
|
||||
max_latency=event.max_latency,
|
||||
timeout=event.timeout,
|
||||
min_ce_length=0,
|
||||
max_ce_length=0,
|
||||
utils.AsyncRunner.spawn(
|
||||
self.send_sync_command(
|
||||
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
|
||||
connection_handle=event.connection_handle,
|
||||
interval_min=event.interval_min,
|
||||
interval_max=event.interval_max,
|
||||
max_latency=event.max_latency,
|
||||
timeout=event.timeout,
|
||||
min_ce_length=0,
|
||||
max_ce_length=0,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1382,9 +1447,9 @@ class Host(utils.EventEmitter):
|
||||
connection_handle=event.connection_handle
|
||||
)
|
||||
|
||||
await self.send_command(response)
|
||||
await self.send_sync_command(response)
|
||||
|
||||
asyncio.create_task(send_long_term_key())
|
||||
utils.AsyncRunner.spawn(send_long_term_key())
|
||||
|
||||
def on_hci_synchronous_connection_complete_event(
|
||||
self, event: hci.HCI_Synchronous_Connection_Complete_Event
|
||||
@@ -1583,9 +1648,9 @@ class Host(utils.EventEmitter):
|
||||
bd_addr=event.bd_addr
|
||||
)
|
||||
|
||||
await self.send_command(response)
|
||||
await self.send_sync_command(response)
|
||||
|
||||
asyncio.create_task(send_link_key())
|
||||
utils.AsyncRunner.spawn(send_link_key())
|
||||
|
||||
def on_hci_io_capability_request_event(
|
||||
self, event: hci.HCI_IO_Capability_Request_Event
|
||||
|
||||
Reference in New Issue
Block a user