Compare commits

..

15 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
3894b14467 better handling of complete/status events 2026-02-02 23:28:40 -08:00
Gilles Boccon-Gibod
e62f947430 add workaround for some buggy controllers 2026-02-02 13:19:55 -08:00
Gilles Boccon-Gibod
dcb8a4b607 Merge pull request #877 from google/gbg/hci-fixes
fix a few HCI types and make the bridge more robust
2026-02-02 11:19:28 -08:00
Gilles Boccon-Gibod
81985c47a9 remove superfluous statement 2026-02-02 11:12:28 -08:00
Gilles Boccon-Gibod
7118328b07 Merge pull request #879 from google/gbg/resolve-when-bonded
resolve addresses when connecting to bonded peers
2026-01-31 11:09:55 -08:00
Josh Wu
c86920558b Merge pull request #878 from zxzxwu/avrcp
AVRCP: SDP record classes and some delegation
2026-01-31 00:01:55 +08:00
Josh Wu
8e6efd0b2f Fix error in AVRCP example 2026-01-30 23:01:11 +08:00
Gilles Boccon-Gibod
2a59e19283 fix comment 2026-01-29 19:09:46 -08:00
Josh Wu
34f5b81c7d AVRCP: Delegate Company ID capabilities 2026-01-29 22:13:14 +08:00
Josh Wu
d34d6a5c98 AVRCP: Delegate Playback Status 2026-01-29 21:33:57 +08:00
Josh Wu
aedc971653 AVRCP: Add SDP record class and finder 2026-01-29 16:00:50 +08:00
Josh Wu
c6815fb820 AVRCP: Delegate passthrough key event 2026-01-29 14:50:14 +08:00
Gilles Boccon-Gibod
f44d013690 make bridge more robust 2026-01-27 09:47:52 -08:00
Gilles Boccon-Gibod
e63dc15ede fix handling of return parameters 2026-01-27 09:39:22 -08:00
Gilles Boccon-Gibod
c901e15666 fix a few HCI types and make the bridge more robust 2026-01-25 13:47:14 -08:00
14 changed files with 913 additions and 338 deletions

View File

@@ -81,7 +81,9 @@ async def async_main():
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci_packet.op_code, command_opcode=hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]), return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode.SUCCESS
),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (bytes(response), True) return (bytes(response), True)

View File

@@ -26,7 +26,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequen
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ClassVar, SupportsBytes, TypeVar from typing import ClassVar, SupportsBytes, TypeVar
from bumble import avc, avctp, core, hci, l2cap, utils from bumble import avc, avctp, core, hci, l2cap, sdp, utils
from bumble.colors import color from bumble.colors import color
from bumble.device import Connection, Device from bumble.device import Connection, Device
from bumble.sdp import ( from bumble.sdp import (
@@ -194,82 +194,43 @@ class TargetFeatures(enum.IntFlag):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_controller_service_sdp_records( @dataclass
service_record_handle: int, class ControllerServiceSdpRecord:
avctp_version: tuple[int, int] = (1, 4), service_record_handle: int
avrcp_version: tuple[int, int] = (1, 6), avctp_version: tuple[int, int] = (1, 4)
supported_features: int | ControllerFeatures = 1, avrcp_version: tuple[int, int] = (1, 6)
) -> list[ServiceAttribute]: supported_features: int | ControllerFeatures = ControllerFeatures(1)
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
attributes = [ def to_service_attributes(self) -> list[ServiceAttribute]:
ServiceAttribute( avctp_version_int = self.avctp_version[0] << 8 | self.avctp_version[1]
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, avrcp_version_int = self.avrcp_version[0] << 8 | self.avrcp_version[1]
DataElement.unsigned_integer_32(service_record_handle),
), attributes = [
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(supported_features),
),
]
if supported_features & ControllerFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute( ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(self.service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.sequence(
[ [
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID), DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16( DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
avctp.AVCTP_BROWSING_PSM
),
] ]
), ),
DataElement.sequence( DataElement.sequence(
@@ -281,87 +242,130 @@ def make_controller_service_sdp_records(
] ]
), ),
), ),
) ServiceAttribute(
return attributes SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(self.supported_features),
),
]
if self.supported_features & ControllerFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(
avctp.AVCTP_BROWSING_PSM
),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
)
return attributes
@classmethod
async def find(cls, connection: Connection) -> list[ControllerServiceSdpRecord]:
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE],
attribute_ids=[
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
],
)
records: list[ControllerServiceSdpRecord] = []
for attribute_lists in search_result:
record = cls(0)
for attribute in attribute_lists:
if attribute.id == SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:
record.service_record_handle = attribute.value.value
elif attribute.id == SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
# [[L2CAP, PSM], [AVCTP, version]]
record.avctp_version = (
attribute.value.value[1].value[1].value >> 8,
attribute.value.value[1].value[1].value & 0xFF,
)
elif (
attribute.id
== SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
):
# [[AV_REMOTE_CONTROL, version]]
record.avrcp_version = (
attribute.value.value[0].value[1].value >> 8,
attribute.value.value[0].value[1].value & 0xFF,
)
elif attribute.id == SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID:
record.supported_features = ControllerFeatures(
attribute.value.value
)
records.append(record)
return records
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_target_service_sdp_records( @dataclass
service_record_handle: int, class TargetServiceSdpRecord:
avctp_version: tuple[int, int] = (1, 4), service_record_handle: int
avrcp_version: tuple[int, int] = (1, 6), avctp_version: tuple[int, int] = (1, 4)
supported_features: int | TargetFeatures = 0x23, avrcp_version: tuple[int, int] = (1, 6)
) -> list[ServiceAttribute]: supported_features: int | TargetFeatures = TargetFeatures(0x23)
# TODO: support a way to compute the supported features from a feature list
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
attributes = [ def to_service_attributes(self) -> list[ServiceAttribute]:
ServiceAttribute( # TODO: support a way to compute the supported features from a feature list
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, avctp_version_int = self.avctp_version[0] << 8 | self.avctp_version[1]
DataElement.unsigned_integer_32(service_record_handle), avrcp_version_int = self.avrcp_version[0] << 8 | self.avrcp_version[1]
),
ServiceAttribute( attributes = [
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(supported_features),
),
]
if supported_features & TargetFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute( ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(self.service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.sequence(
[ [
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID), DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16( DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
avctp.AVCTP_BROWSING_PSM
),
] ]
), ),
DataElement.sequence( DataElement.sequence(
@@ -373,8 +377,90 @@ def make_target_service_sdp_records(
] ]
), ),
), ),
) ServiceAttribute(
return attributes SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(self.supported_features),
),
]
if self.supported_features & TargetFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(
avctp.AVCTP_BROWSING_PSM
),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
)
return attributes
@classmethod
async def find(cls, connection: Connection) -> list[TargetServiceSdpRecord]:
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE],
attribute_ids=[
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
],
)
records: list[TargetServiceSdpRecord] = []
for attribute_lists in search_result:
record = cls(0)
for attribute in attribute_lists:
if attribute.id == SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:
record.service_record_handle = attribute.value.value
elif attribute.id == SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
# [[L2CAP, PSM], [AVCTP, version]]
record.avctp_version = (
attribute.value.value[1].value[1].value >> 8,
attribute.value.value[1].value[1].value & 0xFF,
)
elif (
attribute.id
== SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
):
# [[AV_REMOTE_CONTROL, version]]
record.avrcp_version = (
attribute.value.value[0].value[1].value >> 8,
attribute.value.value[0].value[1].value & 0xFF,
)
elif attribute.id == SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID:
record.supported_features = TargetFeatures(
attribute.value.value
)
records.append(record)
return records
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1204,6 +1290,10 @@ class InformBatteryStatusOfCtResponse(Response):
@dataclass @dataclass
class GetPlayStatusResponse(Response): class GetPlayStatusResponse(Response):
pdu_id = PduId.GET_PLAY_STATUS pdu_id = PduId.GET_PLAY_STATUS
# TG doesn't support Song Length or Position.
UNAVAILABLE = 0xFFFFFFFF
song_length: int = field(metadata=hci.metadata(">4")) song_length: int = field(metadata=hci.metadata(">4"))
song_position: int = field(metadata=hci.metadata(">4")) song_position: int = field(metadata=hci.metadata(">4"))
play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1))
@@ -1521,16 +1611,33 @@ class Delegate:
def __init__(self, status_code: StatusCode) -> None: def __init__(self, status_code: StatusCode) -> None:
self.status_code = status_code self.status_code = status_code
supported_events: list[EventId] class AvcError(Exception):
volume: int """The delegate AVC method failed, with a specified status code."""
def __init__(self, supported_events: Iterable[EventId] = ()) -> None: def __init__(self, status_code: avc.ResponseFrame.ResponseCode) -> None:
self.status_code = status_code
supported_events: list[EventId]
supported_company_ids: list[int]
volume: int
playback_status: PlayStatus
def __init__(
self,
supported_events: Iterable[EventId] = (),
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
) -> None:
self.supported_company_ids = list(supported_company_ids)
self.supported_events = list(supported_events) self.supported_events = list(supported_events)
self.volume = 0 self.volume = 0
self.playback_status = PlayStatus.STOPPED
async def get_supported_events(self) -> list[EventId]: async def get_supported_events(self) -> list[EventId]:
return self.supported_events return self.supported_events
async def get_supported_company_ids(self) -> list[int]:
return self.supported_company_ids
async def set_absolute_volume(self, volume: int) -> None: async def set_absolute_volume(self, volume: int) -> None:
""" """
Set the absolute volume. Set the absolute volume.
@@ -1543,6 +1650,19 @@ class Delegate:
async def get_absolute_volume(self) -> int: async def get_absolute_volume(self) -> int:
return self.volume return self.volume
async def on_key_event(
self,
key: avc.PassThroughFrame.OperationId,
pressed: bool,
data: bytes,
) -> None:
logger.debug(
"@@@ on_key_event: key=%s, pressed=%s, data=%s", key, pressed, data.hex()
)
async def get_playback_status(self) -> PlayStatus:
return self.playback_status
# TODO add other delegate methods # TODO add other delegate methods
@@ -1756,6 +1876,19 @@ class Protocol(utils.EventEmitter):
if isinstance(capability, EventId) if isinstance(capability, EventId)
) )
async def get_supported_company_ids(self) -> list[int]:
"""Get the list of events supported by the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
GetCapabilitiesCommand(GetCapabilitiesCommand.CapabilityId.COMPANY_ID),
)
response = self._check_response(response_context, GetCapabilitiesResponse)
return list(
int.from_bytes(capability, 'big')
for capability in response.capabilities
if isinstance(capability, bytes)
)
async def get_play_status(self) -> SongAndPlayStatus: async def get_play_status(self) -> SongAndPlayStatus:
"""Get the play status of the connected peer.""" """Get the play status of the connected peer."""
response_context = await self.send_avrcp_command( response_context = await self.send_avrcp_command(
@@ -2052,16 +2185,28 @@ class Protocol(utils.EventEmitter):
return return
if isinstance(command, avc.PassThroughCommandFrame): if isinstance(command, avc.PassThroughCommandFrame):
# TODO: delegate
response = avc.PassThroughResponseFrame( async def dispatch_key_event() -> None:
avc.ResponseFrame.ResponseCode.ACCEPTED, try:
command.subunit_type, await self.delegate.on_key_event(
command.subunit_id, command.operation_id,
command.state_flag, command.state_flag == avc.PassThroughFrame.StateFlag.PRESSED,
command.operation_id, command.operation_data,
command.operation_data, )
) response_code = avc.ResponseFrame.ResponseCode.ACCEPTED
self.send_response(transaction_label, response) except Delegate.AvcError as error:
logger.exception("delegate method raised exception")
response_code = error.status_code
except Exception:
logger.exception("delegate method raised exception")
response_code = avc.ResponseFrame.ResponseCode.REJECTED
self.send_passthrough_response(
transaction_label=transaction_label,
command=command,
response_code=response_code,
)
utils.AsyncRunner.spawn(dispatch_key_event())
return return
# TODO handle other types # TODO handle other types
@@ -2141,6 +2286,8 @@ class Protocol(utils.EventEmitter):
self._on_set_absolute_volume_command(transaction_label, command) self._on_set_absolute_volume_command(transaction_label, command)
elif isinstance(command, RegisterNotificationCommand): elif isinstance(command, RegisterNotificationCommand):
self._on_register_notification_command(transaction_label, command) self._on_register_notification_command(transaction_label, command)
elif isinstance(command, GetPlayStatusCommand):
self._on_get_play_status_command(transaction_label, command)
else: else:
# Not supported. # Not supported.
# TODO: check that this is the right way to respond in this case. # TODO: check that this is the right way to respond in this case.
@@ -2364,17 +2511,27 @@ class Protocol(utils.EventEmitter):
logger.debug(f"<<< AVRCP command PDU: {command}") logger.debug(f"<<< AVRCP command PDU: {command}")
async def get_supported_events() -> None: async def get_supported_events() -> None:
capabilities: Sequence[bytes | SupportsBytes]
if ( if (
command.capability_id command.capability_id
!= GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
): ):
raise core.InvalidArgumentError() capabilities = await self.delegate.get_supported_events()
elif (
supported_events = await self.delegate.get_supported_events() command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID
):
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
else:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
self.send_avrcp_response( self.send_avrcp_response(
transaction_label, transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetCapabilitiesResponse(command.capability_id, supported_events), GetCapabilitiesResponse(command.capability_id, capabilities),
) )
self._delegate_command(transaction_label, command, get_supported_events()) self._delegate_command(transaction_label, command, get_supported_events())
@@ -2395,6 +2552,26 @@ class Protocol(utils.EventEmitter):
self._delegate_command(transaction_label, command, set_absolute_volume()) self._delegate_command(transaction_label, command, set_absolute_volume())
def _on_get_play_status_command(
self, transaction_label: int, command: GetPlayStatusCommand
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_playback_status() -> None:
play_status: PlayStatus = await self.delegate.get_playback_status()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetPlayStatusResponse(
# TODO: Delegate this.
song_length=GetPlayStatusResponse.UNAVAILABLE,
song_position=GetPlayStatusResponse.UNAVAILABLE,
play_status=play_status,
),
)
self._delegate_command(transaction_label, command, get_playback_status())
def _on_register_notification_command( def _on_register_notification_command(
self, transaction_label: int, command: RegisterNotificationCommand self, transaction_label: int, command: RegisterNotificationCommand
) -> None: ) -> None:
@@ -2410,28 +2587,27 @@ class Protocol(utils.EventEmitter):
) )
return return
response: Response
if command.event_id == EventId.VOLUME_CHANGED: if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume() volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume)) response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response( elif command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
transaction_label, playback_status = await self.delegate.get_playback_status()
avc.ResponseFrame.ResponseCode.INTERIM, response = RegisterNotificationResponse(
response, PlaybackStatusChangedEvent(play_status=playback_status)
) )
self._register_notification_listener(transaction_label, command) elif command.event_id == EventId.NOW_PLAYING_CONTENT_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(NowPlayingContentChangedEvent())
else:
logger.warning("Event supported but not handled %s", command.event_id)
return return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: self.send_avrcp_response(
# TODO: testing only, use delegate transaction_label,
response = RegisterNotificationResponse( avc.ResponseFrame.ResponseCode.INTERIM,
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING) response,
) )
self.send_avrcp_response( self._register_notification_listener(transaction_label, command)
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
self._delegate_command(transaction_label, command, register_notification()) self._delegate_command(transaction_label, command, register_notification())

View File

@@ -37,7 +37,12 @@ class HCI_Bridge:
def on_packet(self, packet): def on_packet(self, packet):
# Convert the packet bytes to an object # Convert the packet bytes to an object
hci_packet = HCI_Packet.from_bytes(packet) try:
hci_packet = HCI_Packet.from_bytes(packet)
except Exception:
logger.warning('forwarding unparsed packet as-is')
self.hci_sink.on_packet(packet)
return
# Filter the packet # Filter the packet
if self.packet_filter is not None: if self.packet_filter is not None:
@@ -50,7 +55,10 @@ class HCI_Bridge:
return return
# Analyze the packet # Analyze the packet
self.trace(hci_packet) try:
self.trace(hci_packet)
except Exception:
logger.exception('Exception while tracing packet')
# Bridge the packet # Bridge the packet
self.hci_sink.on_packet(packet) self.hci_sink.on_packet(packet)

View File

@@ -1177,7 +1177,7 @@ class ChannelSoundingCapabilities:
rtt_capability: int rtt_capability: int
rtt_aa_only_n: int rtt_aa_only_n: int
rtt_sounding_n: int rtt_sounding_n: int
rtt_random_payload_n: int rtt_random_sequence_n: int
nadm_sounding_capability: int nadm_sounding_capability: int
nadm_random_capability: int nadm_random_capability: int
cs_sync_phys_supported: int cs_sync_phys_supported: int
@@ -2763,24 +2763,39 @@ class Device(utils.CompositeEventEmitter):
logger.warning(f'!!! Command {command.name} timed out') logger.warning(f'!!! Command {command.name} timed out')
raise CommandTimeoutError() from error raise CommandTimeoutError() from error
async def send_sync_command( async def send_sync_command(self, command: hci.HCI_SyncCommand[_RP]) -> _RP:
self, command: hci.HCI_SyncCommand[_RP], check_status: bool = True
) -> _RP:
''' '''
Send a synchronous command via the host. Send a synchronous command via the host.
If the `status` field of the response's `return_parameters` is not equal to
`SUCCESS` an exception is raised.
Params: Params:
command: the command to send. command: the command to send.
check_status: If `True`, check the `status` field of the response's
`return_parameters` and raise and exception if not equal to `SUCCESS`.
Returns: Returns:
An instance of the return parameters class associated with the command class. An instance of the return parameters class associated with the command class.
''' '''
try: try:
return await self.host.send_sync_command( return await self.host.send_sync_command(command, self.command_timeout)
command, check_status, self.command_timeout except asyncio.TimeoutError as error:
) logger.warning(f'!!! Command {command.name} timed out')
raise CommandTimeoutError() from error
async def send_sync_command_raw(
self, command: hci.HCI_SyncCommand[_RP]
) -> hci.HCI_Command_Complete_Event[_RP]:
'''
Send a synchronous command via the host without checking the response.
Params:
command: the command to send.
Returns:
An HCI_Command_Complete_Event instance.
'''
try:
return await self.host.send_sync_command_raw(command, self.command_timeout)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(f'!!! Command {command.name} timed out') logger.warning(f'!!! Command {command.name} timed out')
raise CommandTimeoutError() from error raise CommandTimeoutError() from error
@@ -2797,7 +2812,7 @@ class Device(utils.CompositeEventEmitter):
raise and exception if not equal to `PENDING`. raise and exception if not equal to `PENDING`.
Returns: Returns:
An instance of the return parameters class associated with the command class. A status code.
''' '''
try: try:
return await self.host.send_async_command( return await self.host.send_async_command(
@@ -2812,12 +2827,12 @@ class Device(utils.CompositeEventEmitter):
await self.host.reset() await self.host.reset()
# Try to get the public address from the controller # Try to get the public address from the controller
response = await self.host.send_sync_command( try:
hci.HCI_Read_BD_ADDR_Command(), check_status=False response = await self.host.send_sync_command(hci.HCI_Read_BD_ADDR_Command())
)
if response.status == hci.HCI_SUCCESS:
logger.debug(color(f'BD_ADDR: {response.bd_addr}', 'yellow')) logger.debug(color(f'BD_ADDR: {response.bd_addr}', 'yellow'))
self.public_address = response.bd_addr self.public_address = response.bd_addr
except hci.HCI_Error:
logger.debug('Controller has no public address')
# Instantiate the Key Store (we do this here rather than at __init__ time # Instantiate the Key Store (we do this here rather than at __init__ time
# because some Key Store implementations use the public address as a namespace) # because some Key Store implementations use the public address as a namespace)
@@ -2926,7 +2941,7 @@ class Device(utils.CompositeEventEmitter):
rtt_capability=result.rtt_capability, rtt_capability=result.rtt_capability,
rtt_aa_only_n=result.rtt_aa_only_n, rtt_aa_only_n=result.rtt_aa_only_n,
rtt_sounding_n=result.rtt_sounding_n, rtt_sounding_n=result.rtt_sounding_n,
rtt_random_payload_n=result.rtt_random_payload_n, rtt_random_sequence_n=result.rtt_random_sequence_n,
nadm_sounding_capability=result.nadm_sounding_capability, nadm_sounding_capability=result.nadm_sounding_capability,
nadm_random_capability=result.nadm_random_capability, nadm_random_capability=result.nadm_random_capability,
cs_sync_phys_supported=result.cs_sync_phys_supported, cs_sync_phys_supported=result.cs_sync_phys_supported,
@@ -2954,27 +2969,23 @@ class Device(utils.CompositeEventEmitter):
) )
if self.classic_enabled: if self.classic_enabled:
await self.send_sync_command( await self.send_sync_command_raw(
hci.HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')), hci.HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8'))
check_status=False,
) )
await self.send_sync_command( await self.send_sync_command_raw(
hci.HCI_Write_Class_Of_Device_Command( hci.HCI_Write_Class_Of_Device_Command(
class_of_device=self.class_of_device class_of_device=self.class_of_device
), )
check_status=False,
) )
await self.send_sync_command( await self.send_sync_command_raw(
hci.HCI_Write_Simple_Pairing_Mode_Command( hci.HCI_Write_Simple_Pairing_Mode_Command(
simple_pairing_mode=int(self.classic_ssp_enabled) simple_pairing_mode=int(self.classic_ssp_enabled)
), )
check_status=False,
) )
await self.send_sync_command( await self.send_sync_command_raw(
hci.HCI_Write_Secure_Connections_Host_Support_Command( hci.HCI_Write_Secure_Connections_Host_Support_Command(
secure_connections_host_support=int(self.classic_sc_enabled) secure_connections_host_support=int(self.classic_sc_enabled)
), )
check_status=False,
) )
await self.set_connectable(self.connectable) await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable) await self.set_discoverable(self.discoverable)
@@ -6719,7 +6730,7 @@ class Device(utils.CompositeEventEmitter):
rtt_capability=event.rtt_capability, rtt_capability=event.rtt_capability,
rtt_aa_only_n=event.rtt_aa_only_n, rtt_aa_only_n=event.rtt_aa_only_n,
rtt_sounding_n=event.rtt_sounding_n, rtt_sounding_n=event.rtt_sounding_n,
rtt_random_payload_n=event.rtt_random_payload_n, rtt_random_sequence_n=event.rtt_random_sequence_n,
nadm_sounding_capability=event.nadm_sounding_capability, nadm_sounding_capability=event.nadm_sounding_capability,
nadm_random_capability=event.nadm_random_capability, nadm_random_capability=event.nadm_random_capability,
cs_sync_phys_supported=event.cs_sync_phys_supported, cs_sync_phys_supported=event.cs_sync_phys_supported,

View File

@@ -663,10 +663,13 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]: async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True self.host.ready = True
response1 = await self.host.send_sync_command( response1 = await self.host.send_sync_command_raw(hci.HCI_Reset_Command())
hci.HCI_Reset_Command(), check_status=False if not isinstance(
) response1.return_parameters, hci.HCI_StatusReturnParameters
if response1.status not in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS): ) or response1.return_parameters.status not in (
hci.HCI_UNKNOWN_HCI_COMMAND_ERROR,
hci.HCI_SUCCESS,
):
# When the controller is in operational mode, the response is a # When the controller is in operational mode, the response is a
# successful response. # successful response.
# When the controller is in bootloader mode, # When the controller is in bootloader mode,
@@ -676,13 +679,18 @@ class Driver(common.Driver):
raise DriverError("unexpected HCI response") raise DriverError("unexpected HCI response")
# Read the firmware version. # Read the firmware version.
response2 = await self.host.send_sync_command( response2 = await self.host.send_sync_command_raw(
HCI_Intel_Read_Version_Command(param0=0xFF), check_status=False HCI_Intel_Read_Version_Command(param0=0xFF)
) )
if response2.status != 0: # type: ignore if (
not isinstance(
response2.return_parameters, HCI_Intel_Read_Version_ReturnParameters
)
or response2.return_parameters.status != 0
):
raise DriverError("HCI_Intel_Read_Version_Command error") raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response2.tlv) # type: ignore tlvs = _parse_tlv(response2.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type # Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once. # to appear just once.

View File

@@ -534,11 +534,13 @@ class Driver(common.Driver):
@staticmethod @staticmethod
async def get_loaded_firmware_version(host: Host) -> int | None: async def get_loaded_firmware_version(host: Host) -> int | None:
response1 = await host.send_sync_command( response1 = await host.send_sync_command_raw(HCI_RTK_Read_ROM_Version_Command())
HCI_RTK_Read_ROM_Version_Command(), check_status=False if (
) not isinstance(
response1.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
if response1.status != hci.HCI_SUCCESS: )
or response1.return_parameters.status != hci.HCI_SUCCESS
):
return None return None
response2 = await host.send_sync_command( response2 = await host.send_sync_command(
@@ -559,13 +561,20 @@ class Driver(common.Driver):
await host.send_sync_command(hci.HCI_Reset_Command()) await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command() response = await host.send_sync_command_raw(
response = await host.send_sync_command(command, check_status=False) hci.HCI_Read_Local_Version_Information_Command()
if response.status != hci.HCI_SUCCESS: )
if (
not isinstance(
response.return_parameters,
hci.HCI_Read_Local_Version_Information_ReturnParameters,
)
or response.return_parameters.status != hci.HCI_SUCCESS
):
logger.error("failed to probe local version information") logger.error("failed to probe local version information")
return None return None
local_version = response local_version = response.return_parameters
logger.debug( logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} " f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
@@ -641,15 +650,21 @@ class Driver(common.Driver):
# TODO: load the firmware # TODO: load the firmware
async def download_for_rtl8723b(self): async def download_for_rtl8723b(self) -> int | None:
if self.driver_info.has_rom_version: if self.driver_info.has_rom_version:
response1 = await self.host.send_sync_command( response1 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command(), check_status=False HCI_RTK_Read_ROM_Version_Command()
) )
if response1.status != hci.HCI_SUCCESS: if (
not isinstance(
response1.return_parameters,
HCI_RTK_Read_ROM_Version_ReturnParameters,
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
return None return None
rom_version = response1.version rom_version = response1.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}") logger.debug(f"ROM version before download: {rom_version:04X}")
else: else:
rom_version = 0 rom_version = 0
@@ -691,13 +706,18 @@ class Driver(common.Driver):
logger.debug("download complete!") logger.debug("download complete!")
# Read the version again # Read the version again
response2 = await self.host.send_sync_command( response2 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command(), check_status=False HCI_RTK_Read_ROM_Version_Command()
) )
if response2.status != hci.HCI_SUCCESS: if (
not isinstance(
response2.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response2.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
else: else:
rom_version = response2.version rom_version = response2.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:02X}") logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version return firmware.version

View File

@@ -2407,24 +2407,28 @@ class HCI_Packet:
@classmethod @classmethod
def from_bytes(cls, packet: bytes) -> HCI_Packet: def from_bytes(cls, packet: bytes) -> HCI_Packet:
packet_type = packet[0] try:
packet_type = packet[0]
if packet_type == HCI_COMMAND_PACKET: if packet_type == HCI_COMMAND_PACKET:
return HCI_Command.from_bytes(packet) return HCI_Command.from_bytes(packet)
if packet_type == HCI_ACL_DATA_PACKET: if packet_type == HCI_ACL_DATA_PACKET:
return HCI_AclDataPacket.from_bytes(packet) return HCI_AclDataPacket.from_bytes(packet)
if packet_type == HCI_SYNCHRONOUS_DATA_PACKET: if packet_type == HCI_SYNCHRONOUS_DATA_PACKET:
return HCI_SynchronousDataPacket.from_bytes(packet) return HCI_SynchronousDataPacket.from_bytes(packet)
if packet_type == HCI_EVENT_PACKET: if packet_type == HCI_EVENT_PACKET:
return HCI_Event.from_bytes(packet) return HCI_Event.from_bytes(packet)
if packet_type == HCI_ISO_DATA_PACKET: if packet_type == HCI_ISO_DATA_PACKET:
return HCI_IsoDataPacket.from_bytes(packet) return HCI_IsoDataPacket.from_bytes(packet)
return HCI_CustomPacket(packet) return HCI_CustomPacket(packet)
except Exception as e:
logger.error(f'error parsing HCI packet [{packet.hex()}]: {e}')
raise
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
self.name = name self.name = name
@@ -2597,6 +2601,21 @@ class HCI_GenericReturnParameters(HCI_ReturnParameters):
class HCI_StatusReturnParameters(HCI_ReturnParameters): class HCI_StatusReturnParameters(HCI_ReturnParameters):
status: HCI_ErrorCode = field(metadata=HCI_ErrorCode.type_metadata(1)) status: HCI_ErrorCode = field(metadata=HCI_ErrorCode.type_metadata(1))
@classmethod
def from_parameters(cls, parameters: bytes) -> Self | HCI_StatusReturnParameters:
status = HCI_ErrorCode(parameters[0])
if status != HCI_ErrorCode.SUCCESS:
# Don't parse further, just return the status.
return HCI_StatusReturnParameters(status=status)
return cls(**HCI_Object.dict_from_bytes(parameters, 0, cls.fields))
@dataclasses.dataclass
class HCI_GenericStatusReturnParameters(HCI_StatusReturnParameters):
data: bytes = field(metadata=metadata('*'))
@dataclasses.dataclass @dataclasses.dataclass
class HCI_StatusAndAddressReturnParameters(HCI_StatusReturnParameters): class HCI_StatusAndAddressReturnParameters(HCI_StatusReturnParameters):
@@ -5854,7 +5873,7 @@ class HCI_LE_CS_Read_Local_Supported_Capabilities_ReturnParameters(
rtt_capability: int = field(metadata=metadata(1)) rtt_capability: int = field(metadata=metadata(1))
rtt_aa_only_n: int = field(metadata=metadata(1)) rtt_aa_only_n: int = field(metadata=metadata(1))
rtt_sounding_n: int = field(metadata=metadata(1)) rtt_sounding_n: int = field(metadata=metadata(1))
rtt_random_payload_n: int = field(metadata=metadata(1)) rtt_random_sequence_n: int = field(metadata=metadata(1))
nadm_sounding_capability: int = field(metadata=metadata(2)) nadm_sounding_capability: int = field(metadata=metadata(2))
nadm_random_capability: int = field(metadata=metadata(2)) nadm_random_capability: int = field(metadata=metadata(2))
cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC)) cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC))
@@ -5910,7 +5929,7 @@ class HCI_LE_CS_Write_Cached_Remote_Supported_Capabilities_Command(
rtt_capability: int = field(metadata=metadata(1)) rtt_capability: int = field(metadata=metadata(1))
rtt_aa_only_n: int = field(metadata=metadata(1)) rtt_aa_only_n: int = field(metadata=metadata(1))
rtt_sounding_n: int = field(metadata=metadata(1)) rtt_sounding_n: int = field(metadata=metadata(1))
rtt_random_payload_n: int = field(metadata=metadata(1)) rtt_random_sequence_n: int = field(metadata=metadata(1))
nadm_sounding_capability: int = field(metadata=metadata(2)) nadm_sounding_capability: int = field(metadata=metadata(2))
nadm_random_capability: int = field(metadata=metadata(2)) nadm_random_capability: int = field(metadata=metadata(2))
cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC)) cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC))
@@ -7118,7 +7137,7 @@ class HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event(HCI_LE_Meta_Ev
rtt_capability: int = field(metadata=metadata(1)) rtt_capability: int = field(metadata=metadata(1))
rtt_aa_only_n: int = field(metadata=metadata(1)) rtt_aa_only_n: int = field(metadata=metadata(1))
rtt_sounding_n: int = field(metadata=metadata(1)) rtt_sounding_n: int = field(metadata=metadata(1))
rtt_random_payload_n: int = field(metadata=metadata(1)) rtt_random_sequence_n: int = field(metadata=metadata(1))
nadm_sounding_capability: int = field(metadata=metadata(2)) nadm_sounding_capability: int = field(metadata=metadata(2))
nadm_random_capability: int = field(metadata=metadata(2)) nadm_random_capability: int = field(metadata=metadata(2))
cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC)) cs_sync_phys_supported: int = field(metadata=metadata(CS_SYNC_PHY_SUPPORTED_SPEC))
@@ -7494,6 +7513,7 @@ class HCI_Command_Complete_Event(HCI_Event, Generic[_RP]):
def from_parameters(cls, parameters: bytes) -> Self: def from_parameters(cls, parameters: bytes) -> Self:
event = cls(**HCI_Object.dict_from_bytes(parameters, 0, cls.fields)) event = cls(**HCI_Object.dict_from_bytes(parameters, 0, cls.fields))
event.parameters = parameters event.parameters = parameters
return_parameters_bytes = parameters[3:]
# Find the class for the matching command. # Find the class for the matching command.
subclass = HCI_Command.command_classes.get(event.command_opcode) subclass = HCI_Command.command_classes.get(event.command_opcode)
@@ -7506,16 +7526,16 @@ class HCI_Command_Complete_Event(HCI_Event, Generic[_RP]):
'HCI Command Complete event with opcode for a class that is not' 'HCI Command Complete event with opcode for a class that is not'
' an HCI_SyncCommand subclass: ' ' an HCI_SyncCommand subclass: '
f'opcode={event.command_opcode:#04x}, ' f'opcode={event.command_opcode:#04x}, '
f'type={type(subclass).__name__}' f'type={subclass.__name__}'
) )
event.return_parameters = HCI_GenericReturnParameters( event.return_parameters = HCI_GenericReturnParameters(
data=event.return_parameters # type: ignore[arg-type] data=return_parameters_bytes
) # type: ignore[assignment] ) # type: ignore[assignment]
return event return event
# Parse the return parameters bytes into an object. # Parse the return parameters bytes into an object.
event.return_parameters = subclass.parse_return_parameters( event.return_parameters = subclass.parse_return_parameters(
event.return_parameters # type: ignore[arg-type] return_parameters_bytes
) # type: ignore[assignment] ) # type: ignore[assignment]
return event return event

View File

@@ -270,7 +270,12 @@ class Host(utils.EventEmitter):
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles self.bigs = {} # BIG Handle to BIS Handles
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: asyncio.Future[Any] | None = None self.pending_response: (
asyncio.Future[
hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event
]
| None
) = None
self.number_of_supported_advertising_sets = 0 self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31 self.maximum_advertising_data_length = 31
self.local_version: ( self.local_version: (
@@ -611,22 +616,28 @@ class Host(utils.EventEmitter):
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
): ):
response10 = await self.send_sync_command( try:
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command() response10 = await self.send_sync_command(
) hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
self.number_of_supported_advertising_sets = ( )
response10.num_supported_advertising_sets self.number_of_supported_advertising_sets = (
) response10.num_supported_advertising_sets
)
except hci.HCI_Error:
logger.warning('Failed to read number of supported advertising sets')
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
): ):
response11 = await self.send_sync_command( try:
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command() response11 = await self.send_sync_command(
) hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
self.maximum_advertising_data_length = ( )
response11.max_advertising_data_length self.maximum_advertising_data_length = (
) response11.max_advertising_data_length
)
except hci.HCI_Error:
logger.warning('Failed to read maximum advertising data length')
@property @property
def controller(self) -> TransportSink | None: def controller(self) -> TransportSink | None:
@@ -658,25 +669,35 @@ class Host(utils.EventEmitter):
response_timeout: float | None = None, response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event: ) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
async with self.command_semaphore: await self.command_semaphore.acquire()
assert self.pending_command is None
assert self.pending_response is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future() assert self.pending_command is None
self.pending_command = command assert self.pending_response is None
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_command = command
try: response: (
self.send_hci_packet(command) hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event | None
return await asyncio.wait_for( ) = None
self.pending_response, timeout=response_timeout try:
) self.send_hci_packet(command)
except Exception: response = await asyncio.wait_for(
logger.exception(color("!!! Exception while sending command:", "red")) self.pending_response, timeout=response_timeout
raise )
finally: return response
self.pending_command = None except Exception:
self.pending_response = None logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
):
self.command_semaphore.release()
@overload @overload
async def send_command( async def send_command(
@@ -729,30 +750,56 @@ class Host(utils.EventEmitter):
return response return response
async def send_sync_command( async def send_sync_command(
self, command: hci.HCI_SyncCommand[_RP], response_timeout: float | None = None
) -> _RP:
response = await self.send_sync_command_raw(command, response_timeout)
return_parameters = response.return_parameters
# Check the return parameters's status
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
elif isinstance(return_parameters, hci.HCI_GenericReturnParameters):
# if the payload has at least one byte, assume the first byte is the status
if not return_parameters.data:
raise RuntimeError('no status byte in return parameters')
status = hci.HCI_ErrorCode(return_parameters.data[0])
else:
raise RuntimeError(
f'unexpected return parameters type ({type(return_parameters)})'
)
if status != hci.HCI_ErrorCode.SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_sync_command_raw(
self, self,
command: hci.HCI_SyncCommand[_RP], command: hci.HCI_SyncCommand[_RP],
check_status: bool = True,
response_timeout: float | None = None, response_timeout: float | None = None,
) -> _RP: ) -> hci.HCI_Command_Complete_Event[_RP]:
response = await self._send_command(command, response_timeout) response = await self._send_command(command, response_timeout)
# For unknown HCI commands, some controllers return Command Status instead of
# Command Complete.
if (
isinstance(response, hci.HCI_Command_Status_Event)
and response.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
):
return hci.HCI_Command_Complete_Event(
num_hci_command_packets=response.num_hci_command_packets,
command_opcode=command.op_code,
return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode(response.status)
), # type: ignore
)
# Check that the response is of the expected type # Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event) assert isinstance(response, hci.HCI_Command_Complete_Event)
return_parameters: _RP = response.return_parameters
assert isinstance(return_parameters, command.return_parameters_class)
# Check the return parameters if required return response
if check_status:
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_async_command( async def send_async_command(
self, self,
@@ -762,19 +809,25 @@ class Host(utils.EventEmitter):
) -> hci.HCI_ErrorCode: ) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout) response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type # For unknown HCI commands, some controllers return Command Complete instead of
assert isinstance(response, hci.HCI_Command_Status_Event) # Command Status.
if isinstance(response, hci.HCI_Command_Complete_Event):
# Assume the first byte of the return parameters is the status
if (
status := hci.HCI_ErrorCode(response.parameters[3])
) != hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR:
logger.warning(f'unexpected return paramerers status {status}')
else:
assert isinstance(response, hci.HCI_Command_Status_Event)
status = hci.HCI_ErrorCode(response.status)
# Check the return parameters if required # Check the status if required
status = response.status
if check_status: if check_status:
if status != hci.HCI_CommandStatus.PENDING: if status != hci.HCI_CommandStatus.PENDING:
logger.warning( logger.warning(f'{command.name} failed ' f'({status.name})')
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status) raise hci.HCI_Error(status)
return hci.HCI_ErrorCode(status) return status
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.") @utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None: def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
@@ -1003,6 +1056,8 @@ class Host(utils.EventEmitter):
self.pending_response.set_result(event) self.pending_response.set_result(event)
else: else:
logger.warning('!!! no pending response future to set') logger.warning('!!! no pending response future to set')
if event.num_hci_command_packets and self.command_semaphore.locked():
self.command_semaphore.release()
############################################################ ############################################################
# HCI handlers # HCI handlers
@@ -1014,7 +1069,13 @@ class Host(utils.EventEmitter):
if event.command_opcode == 0: if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to # This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command # an actual command
logger.debug('no-command event') logger.debug('no-command event for flow control')
# Release the command semaphore if needed
if event.num_hci_command_packets and self.command_semaphore.locked():
logger.debug('command complete event releasing semaphore')
self.command_semaphore.release()
return return
return self.on_command_processed(event) return self.on_command_processed(event)

View File

@@ -194,7 +194,7 @@ async def open_android_netsim_controller_transport(
# We only accept BLUETOOTH # We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH: if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type') logger.debug('Request for unsupported chip type')
error = PacketResponse(error='Unsupported chip type') error = PacketResponse(error='Unsupported chip type')
await self.context.write(error) await self.context.write(error)
# return # return

View File

@@ -42,8 +42,7 @@ response = await host.send_sync_command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV, handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0, connection_handle=0,
tx_power_level=-4, tx_power_level=-4,
), )
check_status=False
) )
if response.status == HCI_SUCCESS: if response.status == HCI_SUCCESS:

View File

@@ -25,7 +25,7 @@ import sys
import websockets.asyncio.server import websockets.asyncio.server
import bumble.logging import bumble.logging
from bumble import a2dp, avc, avdtp, avrcp, utils from bumble import a2dp, avc, avdtp, avrcp, sdp, utils
from bumble.core import PhysicalTransport from bumble.core import PhysicalTransport
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport from bumble.transport import open_transport
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sdp_records(): def sdp_records() -> dict[int, list[sdp.ServiceAttribute]]:
a2dp_sink_service_record_handle = 0x00010001 a2dp_sink_service_record_handle = 0x00010001
avrcp_controller_service_record_handle = 0x00010002 avrcp_controller_service_record_handle = 0x00010002
avrcp_target_service_record_handle = 0x00010003 avrcp_target_service_record_handle = 0x00010003
@@ -43,17 +43,17 @@ def sdp_records():
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records( a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
a2dp_sink_service_record_handle a2dp_sink_service_record_handle
), ),
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records( avrcp_controller_service_record_handle: avrcp.ControllerServiceSdpRecord(
avrcp_controller_service_record_handle avrcp_controller_service_record_handle
), ).to_service_attributes(),
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records( avrcp_target_service_record_handle: avrcp.TargetServiceSdpRecord(
avrcp_controller_service_record_handle avrcp_target_service_record_handle
), ).to_service_attributes(),
} }
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def codec_capabilities(): def codec_capabilities() -> avdtp.MediaCodecCapabilities:
return avdtp.MediaCodecCapabilities( return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE, media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE, media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
@@ -81,20 +81,22 @@ def codec_capabilities():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection(server): def on_avdtp_connection(server: avdtp.Protocol) -> None:
# Add a sink endpoint to the server # Add a sink endpoint to the server
sink = server.add_sink(codec_capabilities()) sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet) sink.on(sink.EVENT_RTP_PACKET, on_rtp_packet)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_rtp_packet(packet): def on_rtp_packet(packet: avdtp.MediaPacket) -> None:
print(f'RTP: {packet}') print(f'RTP: {packet}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer): def on_avrcp_start(
async def get_supported_events(): avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer
) -> None:
async def get_supported_events() -> None:
events = await avrcp_protocol.get_supported_events() events = await avrcp_protocol.get_supported_events()
print("SUPPORTED EVENTS:", events) print("SUPPORTED EVENTS:", events)
websocket_server.send_message( websocket_server.send_message(
@@ -130,14 +132,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
utils.AsyncRunner.spawn(get_supported_events()) utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed(): async def monitor_track_changed() -> None:
async for identifier in avrcp_protocol.monitor_track_changed(): async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex()) print("TRACK CHANGED:", identifier.hex())
websocket_server.send_message( websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}} {"type": "track-changed", "params": {"identifier": identifier.hex()}}
) )
async def monitor_playback_status(): async def monitor_playback_status() -> None:
async for playback_status in avrcp_protocol.monitor_playback_status(): async for playback_status in avrcp_protocol.monitor_playback_status():
print("PLAYBACK STATUS CHANGED:", playback_status.name) print("PLAYBACK STATUS CHANGED:", playback_status.name)
websocket_server.send_message( websocket_server.send_message(
@@ -147,7 +149,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_playback_position(): async def monitor_playback_position() -> None:
async for playback_position in avrcp_protocol.monitor_playback_position( async for playback_position in avrcp_protocol.monitor_playback_position(
playback_interval=1 playback_interval=1
): ):
@@ -159,7 +161,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_player_application_settings(): async def monitor_player_application_settings() -> None:
async for settings in avrcp_protocol.monitor_player_application_settings(): async for settings in avrcp_protocol.monitor_player_application_settings():
print("PLAYER APPLICATION SETTINGS:", settings) print("PLAYER APPLICATION SETTINGS:", settings)
settings_as_dict = [ settings_as_dict = [
@@ -173,14 +175,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_available_players(): async def monitor_available_players() -> None:
async for _ in avrcp_protocol.monitor_available_players(): async for _ in avrcp_protocol.monitor_available_players():
print("AVAILABLE PLAYERS CHANGED") print("AVAILABLE PLAYERS CHANGED")
websocket_server.send_message( websocket_server.send_message(
{"type": "available-players-changed", "params": {}} {"type": "available-players-changed", "params": {}}
) )
async def monitor_addressed_player(): async def monitor_addressed_player() -> None:
async for player in avrcp_protocol.monitor_addressed_player(): async for player in avrcp_protocol.monitor_addressed_player():
print("ADDRESSED PLAYER CHANGED") print("ADDRESSED PLAYER CHANGED")
websocket_server.send_message( websocket_server.send_message(
@@ -195,7 +197,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_uids(): async def monitor_uids() -> None:
async for uid_counter in avrcp_protocol.monitor_uids(): async for uid_counter in avrcp_protocol.monitor_uids():
print("UIDS CHANGED") print("UIDS CHANGED")
websocket_server.send_message( websocket_server.send_message(
@@ -207,7 +209,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_volume(): async def monitor_volume() -> None:
async for volume in avrcp_protocol.monitor_volume(): async for volume in avrcp_protocol.monitor_volume():
print("VOLUME CHANGED:", volume) print("VOLUME CHANGED:", volume)
websocket_server.send_message( websocket_server.send_message(
@@ -360,7 +362,7 @@ async def main() -> None:
# Create a listener to wait for AVDTP connections # Create a listener to wait for AVDTP connections
listener = avdtp.Listener(avdtp.Listener.create_registrar(device)) listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
listener.on('connection', on_avdtp_connection) listener.on(listener.EVENT_CONNECTION, on_avdtp_connection)
avrcp_delegate = Delegate() avrcp_delegate = Delegate()
avrcp_protocol = avrcp.Protocol(avrcp_delegate) avrcp_protocol = avrcp.Protocol(avrcp_delegate)

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import struct import struct
from collections.abc import Sequence from collections.abc import Sequence
@@ -422,6 +423,47 @@ def test_passthrough_commands():
assert bytes(parsed) == play_pressed_bytes assert bytes(parsed) == play_pressed_bytes
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_sdp_records():
two_devices = await TwoDevices.create_with_avdtp()
# Add SDP records to device 1
controller_record = avrcp.ControllerServiceSdpRecord(
service_record_handle=0x10001,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.ControllerFeatures.CATEGORY_1
| avrcp.ControllerFeatures.SUPPORTS_BROWSING
),
)
target_record = avrcp.TargetServiceSdpRecord(
service_record_handle=0x10002,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.TargetFeatures.CATEGORY_1 | avrcp.TargetFeatures.SUPPORTS_BROWSING
),
)
two_devices.devices[1].sdp_service_records = {
0x10001: controller_record.to_service_attributes(),
0x10002: target_record.to_service_attributes(),
}
# Find records from device 0
controller_records = await avrcp.ControllerServiceSdpRecord.find(
two_devices.connections[0]
)
assert len(controller_records) == 1
assert controller_records[0] == controller_record
target_records = await avrcp.TargetServiceSdpRecord.find(two_devices.connections[0])
assert len(target_records) == 1
assert target_records[0] == target_record
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_supported_events(): async def test_get_supported_events():
@@ -436,6 +478,163 @@ async def test_get_supported_events():
assert supported_events == [avrcp.EventId.VOLUME_CHANGED] assert supported_events == [avrcp.EventId.VOLUME_CHANGED]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event():
two_devices = await TwoDevices.create_with_avdtp()
q = asyncio.Queue[tuple[avc.PassThroughFrame.OperationId, bool, bytes]]()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
q.put_nowait((key, pressed, data))
two_devices.protocols[1].delegate = Delegate()
for key, pressed in [
(avc.PassThroughFrame.OperationId.PLAY, True),
(avc.PassThroughFrame.OperationId.PLAY, False),
(avc.PassThroughFrame.OperationId.PAUSE, True),
(avc.PassThroughFrame.OperationId.PAUSE, False),
]:
await two_devices.protocols[0].send_key_event(key, pressed)
assert (await q.get()) == (key, pressed, b'')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_rejected():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise avrcp.Delegate.AvcError(avc.ResponseFrame.ResponseCode.REJECTED)
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_exception():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise Exception()
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_volume():
two_devices = await TwoDevices.create_with_avdtp()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
response = await two_devices.protocols[1].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL, avrcp.SetAbsoluteVolumeCommand(volume)
)
assert isinstance(response.response, avrcp.SetAbsoluteVolumeResponse)
assert response.response.volume == volume
assert two_devices.protocols[0].delegate.volume == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate.playback_status = status
response = await two_devices.protocols[1].get_play_status()
assert response.play_status == status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_company_ids():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate = avrcp.Delegate(
supported_company_ids=[avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
)
supported_company_ids = await two_devices.protocols[
1
].get_supported_company_ids()
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_volume():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
volume_iter = two_devices.protocols[0].monitor_volume()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
# Interim
two_devices.protocols[1].delegate.volume = 0
assert (await anext(volume_iter)) == 0
# Changed
two_devices.protocols[1].notify_volume_changed(volume)
assert (await anext(volume_iter)) == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.PLAYBACK_STATUS_CHANGED]
)
playback_status_iter = two_devices.protocols[0].monitor_playback_status()
for playback_status in avrcp.PlayStatus:
# Interim
two_devices.protocols[1].delegate.playback_status = avrcp.PlayStatus.STOPPED
assert (await anext(playback_status_iter)) == avrcp.PlayStatus.STOPPED
# Changed
two_devices.protocols[1].notify_playback_status_changed(playback_status)
assert (await anext(playback_status_iter)) == playback_status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_now_playing_content():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.NOW_PLAYING_CONTENT_CHANGED]
)
now_playing_iter = two_devices.protocols[0].monitor_now_playing_content()
for _ in range(2):
# Interim
await anext(now_playing_iter)
# Changed
two_devices.protocols[1].notify_now_playing_content_changed()
await anext(now_playing_iter)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_frame_parser() test_frame_parser()

View File

@@ -218,9 +218,9 @@ def test_return_parameters() -> None:
assert isinstance(params.status, utils.OpenIntEnum) assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters( params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('3C001122334455') bytes.fromhex('00001122334455')
) )
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.status, utils.OpenIntEnum) assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address) assert isinstance(params.bd_addr, hci.Address)
@@ -232,6 +232,14 @@ def test_return_parameters() -> None:
assert len(params.local_name) == 248 assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello' assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# Some return parameters may be shorter than the full length
# (for Command Complete events with errors)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('010011223344')
)
assert isinstance(params, hci.HCI_StatusReturnParameters)
assert params.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command(): def test_HCI_Command():

View File

@@ -26,9 +26,14 @@ from bumble.controller import Controller
from bumble.hci import ( from bumble.hci import (
HCI_AclDataPacket, HCI_AclDataPacket,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_CommandStatus,
HCI_Disconnect_Command,
HCI_Error, HCI_Error,
HCI_ErrorCode, HCI_ErrorCode,
HCI_Event, HCI_Event,
HCI_GenericReturnParameters,
HCI_LE_Terminate_BIG_Command,
HCI_Reset_Command, HCI_Reset_Command,
HCI_StatusReturnParameters, HCI_StatusReturnParameters,
) )
@@ -195,6 +200,7 @@ async def test_send_sync_command() -> None:
) )
host = Host(source, sink) host = Host(source, sink)
host.ready = True
# Sync command with success # Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command()) response1 = await host.send_sync_command(HCI_Reset_Command())
@@ -212,6 +218,61 @@ async def test_send_sync_command() -> None:
assert excinfo.value.error_code == error_response.return_parameters.status assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with error status should not raise when `check_status` is False # Sync command with raw result
response2 = await host.send_sync_command(HCI_Reset_Command(), check_status=False) response2 = await host.send_sync_command_raw(HCI_Reset_Command())
assert response2.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR assert response2.return_parameters.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR
# Sync command with a command that's not an HCI_SyncCommand
# (here, for convenience, we use an HCI_AsyncCommand instance)
command = HCI_Disconnect_Command(connection_handle=0x1234, reason=0x13)
sink.response = HCI_Command_Complete_Event(
1,
command.op_code,
HCI_GenericReturnParameters(data=bytes.fromhex("00112233")),
)
response3 = await host.send_sync_command_raw(command) # type: ignore
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Status_Event(
HCI_CommandStatus.PENDING,
1,
HCI_Reset_Command.op_code,
),
)
host = Host(source, sink)
host.ready = True
# Normal pending status
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0)
)
assert response == HCI_CommandStatus.PENDING
# Unknown HCI command result returned as a Command Status
sink.response = HCI_Command_Status_Event(
HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR,
1,
HCI_LE_Terminate_BIG_Command.op_code,
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# Unknown HCI command result returned as a Command Complete
sink.response = HCI_Command_Complete_Event(
1,
HCI_LE_Terminate_BIG_Command.op_code,
HCI_StatusReturnParameters(HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR),
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR