Compare commits

..

25 Commits

Author SHA1 Message Date
zxzxwu
a275c399a3 Merge pull request #734 from khsiao-google/le_subrating
Support LE Subrating
2025-08-07 16:52:17 +08:00
zxzxwu
c98275f385 Merge pull request #743 from zxzxwu/ascs
ASCS: Handle when CIS link is established before enable
2025-08-06 12:18:52 +08:00
khsiao-google
0b19347bef Only reset subrate_factor and continuation_number when connection interval changes 2025-08-06 03:55:41 +00:00
Josh Wu
f61fd64c0b ASCS: Handle when CIS link is established before enable 2025-08-05 17:31:42 +08:00
khsiao-google
ec12771be6 Support HCI_LE_Set_Host_Feature_Command 2025-08-05 05:56:00 +00:00
Gilles Boccon-Gibod
5b33e715da Merge pull request #742 from barbibulle/gbg/enable-manual-workflow-run 2025-08-04 20:57:23 -07:00
Gilles Boccon-Gibod
b885f29318 Merge pull request #740 from barbibulle/gbg/fix-735 2025-08-04 20:57:04 -07:00
Gilles Boccon-Gibod
7ca13188d5 Merge pull request #741 from barbibulle/gbg/update-black 2025-08-04 20:56:40 -07:00
Gilles Boccon-Gibod
89586d5d18 enable manual workflow runs 2025-08-04 19:46:04 -07:00
Gilles Boccon-Gibod
381032ceb9 update to black 25.1 2025-08-04 19:32:52 -07:00
Gilles Boccon-Gibod
12ca1c01f0 Revert "update to black formatter 25.1"
This reverts commit c034297bc0.
2025-08-04 19:24:30 -07:00
Gilles Boccon-Gibod
a7111d0107 send public keys earlier 2025-08-04 19:18:12 -07:00
Gilles Boccon-Gibod
c034297bc0 update to black formatter 25.1 2025-08-02 21:11:34 -07:00
Gilles Boccon-Gibod
a1eff958e6 do not wait for display 2025-08-02 21:10:45 -07:00
khsiao-google
d6282a7247 Support LE Subrating reply to comments 2025-08-03 03:39:23 +00:00
Gilles Boccon-Gibod
efdc770fde Merge pull request #737 from leifdreizler/fix-spdx-license
Update license field to use proper SPDX identifier
2025-08-02 11:22:58 -07:00
Leif
357d7f9c22 Update pyproject.toml 2025-08-02 08:18:36 -04:00
Leif Dreizler
3bc08b4e0d Update license field to use proper SPDX identifier
This changes the license field to be a valid [SPDX identifier](https://spdx.org/licenses) aligning with [PEP 639](https://peps.python.org/pep-0639/#project-source-metadata). This populates the `license_expression` field in the PyPI API and is used by downstream tools including deps.dev

These changes were generated by Claude after reviewing the license and manifest files in your repository, but opened and reviewed by me. Please let me know if the analysis is incorrect and thanks for being an OSS maintainer.
2025-08-01 20:19:25 -04:00
khsiao-google
982aaeabc3 Support LE Subrating 2025-07-31 02:52:42 +00:00
Gilles Boccon-Gibod
1dc0950177 Merge pull request #730 from google/gbg/apple-media-service
basic AMS implementation
2025-07-29 22:34:25 -07:00
zxzxwu
df0fd74533 Merge pull request #733 from zxzxwu/l2cap
Fix L2CAP_Control_Frame errors
2025-07-30 13:12:44 +08:00
Josh Wu
822f97fa84 Fix L2CAP errors 2025-07-30 12:00:20 +08:00
Gilles Boccon-Gibod
4a6b0ef840 Merge pull request #732 from google/gbg/722
fix #722
2025-07-29 10:50:02 -07:00
Gilles Boccon-Gibod
bf8a2cdcb5 add discrete command methods 2025-07-26 20:24:55 -07:00
Gilles Boccon-Gibod
4bf7448a01 basic AMS implementation 2025-07-22 14:57:52 -07:00
22 changed files with 1063 additions and 87 deletions

View File

@@ -6,6 +6,8 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
branches: [main]
permissions:
contents: read

View File

@@ -17,6 +17,8 @@ on:
pull_request:
# The branches below must be a subset of the branches above
branches: [ main ]
workflow_dispatch:
branches: [main]
schedule:
- cron: '39 21 * * 4'

View File

@@ -7,6 +7,10 @@ on:
branches: [ main ]
paths:
- 'extras/android/BtBench/**'
workflow_dispatch:
branches: [main]
paths:
- 'extras/android/BtBench/**'
permissions:
contents: read

View File

@@ -5,6 +5,8 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
branches: [main]
permissions:
contents: read

View File

@@ -6,6 +6,8 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
branches: [main]
permissions:
contents: read

View File

@@ -1117,7 +1117,7 @@ class Protocol(utils.EventEmitter):
@staticmethod
def _check_vendor_dependent_frame(
frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame]
frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame],
) -> bool:
if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID:
logger.debug("unsupported company id, ignoring")

View File

@@ -1269,6 +1269,56 @@ class Controller:
)
return bytes([HCI_SUCCESS]) + bd_addr
def on_hci_le_set_default_subrate_command(
self, command: hci.HCI_LE_Set_Default_Subrate_Command
):
'''
See Bluetooth spec Vol 6, Part E - 7.8.123 LE Set Event Mask Command
'''
if (
command.subrate_max * (command.max_latency) > 500
or command.subrate_max < command.subrate_min
or command.continuation_number >= command.subrate_max
):
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
return bytes([HCI_SUCCESS])
def on_hci_le_subrate_request_command(
self, command: hci.HCI_LE_Subrate_Request_Command
):
'''
See Bluetooth spec Vol 6, Part E - 7.8.124 LE Subrate Request command
'''
if (
command.subrate_max * (command.max_latency) > 500
or command.continuation_number < command.continuation_number
or command.subrate_max < command.subrate_min
or command.continuation_number >= command.subrate_max
):
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
self.send_hci_packet(
hci.HCI_Command_Status_Event(
status=hci.HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.send_hci_packet(
hci.HCI_LE_Subrate_Change_Event(
status=hci.HCI_SUCCESS,
connection_handle=command.connection_handle,
subrate_factor=2,
peripheral_latency=2,
continuation_number=command.continuation_number,
supervision_timeout=command.supervision_timeout,
)
)
return None
def on_hci_le_set_event_mask_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.1 LE Set Event Mask Command
@@ -1815,3 +1865,11 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_set_host_feature_command(
self, _command: hci.HCI_LE_Set_Host_Feature_Command
):
'''
See Bluetooth spec Vol 4, Part E - 7.8.115 LE Set Host Feature command
'''
return bytes([HCI_SUCCESS])

View File

@@ -1752,6 +1752,8 @@ class Connection(utils.CompositeEventEmitter):
EVENT_CIS_REQUEST = "cis_request"
EVENT_CIS_ESTABLISHMENT = "cis_establishment"
EVENT_CIS_ESTABLISHMENT_FAILURE = "cis_establishment_failure"
EVENT_LE_SUBRATE_CHANGE = "le_subrate_change"
EVENT_LE_SUBRATE_CHANGE_FAILURE = "le_subrate_change_failure"
@utils.composite_listener
class Listener:
@@ -1787,6 +1789,12 @@ class Connection(utils.CompositeEventEmitter):
connection_interval: float # Connection interval, in milliseconds. [LE only]
peripheral_latency: int # Peripheral latency, in number of intervals. [LE only]
supervision_timeout: float # Supervision timeout, in milliseconds.
subrate_factor: int = (
1 # See Bluetooth spec Vol 6, Part B - 4.5.1 Connection events
)
continuation_number: int = (
0 # See Bluetooth spec Vol 6, Part B - 4.5.1 Connection events
)
def __init__(
self,
@@ -2058,6 +2066,7 @@ class DeviceConfiguration:
le_simultaneous_enabled: bool = False
le_privacy_enabled: bool = False
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
le_subrate_enabled: bool = False
classic_enabled: bool = False
classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True
@@ -2410,6 +2419,7 @@ class Device(utils.CompositeEventEmitter):
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
self.le_subrate_enabled = config.le_subrate_enabled
self.classic_enabled = config.classic_enabled
self.cis_enabled = config.cis_enabled
self.classic_sc_enabled = config.classic_sc_enabled
@@ -2789,6 +2799,15 @@ class Device(utils.CompositeEventEmitter):
check_result=True,
)
if self.le_subrate_enabled:
await self.send_command(
hci.HCI_LE_Set_Host_Feature_Command(
bit_number=hci.LeFeature.CONNECTION_SUBRATING_HOST_SUPPORT,
bit_value=1,
),
check_result=True,
)
if self.config.channel_sounding_enabled:
await self.send_command(
hci.HCI_LE_Set_Host_Feature_Command(
@@ -6189,11 +6208,23 @@ class Device(utils.CompositeEventEmitter):
f'{connection.peer_address} as {connection.role_name}, '
f'{connection_parameters}'
)
connection.parameters = Connection.Parameters(
connection_parameters.connection_interval * 1.25,
connection_parameters.peripheral_latency,
connection_parameters.supervision_timeout * 10.0,
)
if (
connection.parameters.connection_interval
!= connection_parameters.connection_interval * 1.25
):
connection.parameters = Connection.Parameters(
connection_parameters.connection_interval * 1.25,
connection_parameters.peripheral_latency,
connection_parameters.supervision_timeout * 10.0,
)
else:
connection.parameters = Connection.Parameters(
connection_parameters.connection_interval * 1.25,
connection_parameters.peripheral_latency,
connection_parameters.supervision_timeout * 10.0,
connection.parameters.subrate_factor,
connection.parameters.continuation_number,
)
connection.emit(connection.EVENT_CONNECTION_PARAMETERS_UPDATE)
@host_event_handler
@@ -6226,6 +6257,25 @@ class Device(utils.CompositeEventEmitter):
)
connection.emit(connection.EVENT_CONNECTION_PHY_UPDATE_FAILURE, error)
@host_event_handler
@with_connection_from_handle
def on_le_subrate_change(
self,
connection: Connection,
subrate_factor: int,
peripheral_latency: int,
continuation_number: int,
supervision_timeout: int,
):
connection.parameters = Connection.Parameters(
connection.parameters.connection_interval,
peripheral_latency,
supervision_timeout * 10.0,
subrate_factor,
continuation_number,
)
connection.emit(connection.EVENT_LE_SUBRATE_CHANGE)
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection, att_mtu):

View File

@@ -5315,6 +5315,37 @@ class HCI_LE_Set_Host_Feature_Command(HCI_Command):
bit_value: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Set_Default_Subrate_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.123 LE Set Default Subrate command
'''
subrate_min: int = field(metadata=metadata(2))
subrate_max: int = field(metadata=metadata(2))
max_latency: int = field(metadata=metadata(2))
continuation_number: int = field(metadata=metadata(2))
supervision_timeout: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Subrate_Request_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.124 LE Subrate Request command
'''
connection_handle: int = field(metadata=metadata(2))
subrate_min: int = field(metadata=metadata(2))
subrate_max: int = field(metadata=metadata(2))
max_latency: int = field(metadata=metadata(2))
continuation_number: int = field(metadata=metadata(2))
supervision_timeout: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -6460,6 +6491,22 @@ class HCI_LE_BIGInfo_Advertising_Report_Event(HCI_LE_Meta_Event):
encryption: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event
@dataclasses.dataclass
class HCI_LE_Subrate_Change_Event(HCI_LE_Meta_Event):
'''
See Bluetooth spec @ 7.7.65.35 LE Subrate Change event
'''
status: int = field(metadata=metadata(STATUS_SPEC))
connection_handle: int = field(metadata=metadata(2))
subrate_factor: int = field(metadata=metadata(2))
peripheral_latency: int = field(metadata=metadata(2))
continuation_number: int = field(metadata=metadata(2))
supervision_timeout: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event
@dataclasses.dataclass

View File

@@ -1645,5 +1645,15 @@ class Host(utils.EventEmitter):
def on_hci_le_cs_subevent_result_continue_event(self, event):
self.emit('cs_subevent_result_continue', event)
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
self.emit(
'le_subrate_change',
event.connection_handle,
event.subrate_factor,
event.peripheral_latency,
event.continuation_number,
event.supervision_timeout,
)
def on_hci_vendor_event(self, event):
self.emit('vendor_event', event)

View File

@@ -213,7 +213,7 @@ class L2CAP_Control_Frame:
fields: ClassVar[hci.Fields] = ()
code: int = dataclasses.field(default=0, init=False)
name: str = dataclasses.field(default='', init=False)
_data: Optional[bytes] = dataclasses.field(default=None, init=False)
_payload: Optional[bytes] = dataclasses.field(default=None, init=False)
identifier: int
@@ -223,7 +223,8 @@ class L2CAP_Control_Frame:
subclass = L2CAP_Control_Frame.classes.get(code)
if subclass is None:
instance = L2CAP_Control_Frame(pdu)
instance = L2CAP_Control_Frame(identifier=identifier)
instance.payload = pdu[4:]
instance.code = CommandCode(code)
instance.name = instance.code.name
return instance
@@ -232,11 +233,11 @@ class L2CAP_Control_Frame:
identifier=identifier,
)
frame.identifier = identifier
frame.data = pdu[4:]
if length != len(pdu):
frame.payload = pdu[4:]
if length != len(frame.payload):
logger.warning(
color(
f'!!! length mismatch: expected {len(pdu) - 4} but got {length}',
f'!!! length mismatch: expected {length} but got {len(frame.payload)}',
'red',
)
)
@@ -273,34 +274,20 @@ class L2CAP_Control_Frame:
return subclass
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
self.identifier = kwargs.get('identifier', 0)
if self.fields:
if kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = hci.HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
self.data = pdu[4:] if pdu else b''
@property
def data(self) -> bytes:
if self._data is None:
self._data = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._data
def payload(self) -> bytes:
if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@data.setter
def data(self, parameters: bytes) -> None:
self._data = parameters
@payload.setter
def payload(self, payload: bytes) -> None:
self._payload = payload
def __bytes__(self) -> bytes:
return (
struct.pack('<BBH', self.code, self.identifier, len(self.data) + 4)
+ self.data
struct.pack('<BBH', self.code, self.identifier, len(self.payload))
+ self.payload
)
def __str__(self) -> str:
@@ -308,8 +295,8 @@ class L2CAP_Control_Frame:
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.data) > 1:
result += f': {self.data.hex()}'
if len(self.payload) > 1:
result += f': {self.payload.hex()}'
return result

View File

@@ -49,7 +49,7 @@ _SERVICERS_HOOKS: list[Callable[[PandoraDevice, Config, grpc.aio.Server], None]]
def register_servicer_hook(
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None],
) -> None:
_SERVICERS_HOOKS.append(hook)

404
bumble/profiles/ams.py Normal file
View File

@@ -0,0 +1,404 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apple Media Service (AMS).
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import logging
from typing import Optional, Iterable, Union
from bumble.device import Peer
from bumble.gatt import (
Characteristic,
GATT_AMS_SERVICE,
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
TemplateService,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
from bumble import utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol
# -----------------------------------------------------------------------------
class RemoteCommandId(utils.OpenIntEnum):
PLAY = 0
PAUSE = 1
TOGGLE_PLAY_PAUSE = 2
NEXT_TRACK = 3
PREVIOUS_TRACK = 4
VOLUME_UP = 5
VOLUME_DOWN = 6
ADVANCE_REPEAT_MODE = 7
ADVANCE_SHUFFLE_MODE = 8
SKIP_FORWARD = 9
SKIP_BACKWARD = 10
LIKE_TRACK = 11
DISLIKE_TRACK = 12
BOOKMARK_TRACK = 13
class EntityId(utils.OpenIntEnum):
PLAYER = 0
QUEUE = 1
TRACK = 2
class ActionId(utils.OpenIntEnum):
POSITIVE = 0
NEGATIVE = 1
class EntityUpdateFlags(enum.IntFlag):
TRUNCATED = 1
class PlayerAttributeId(utils.OpenIntEnum):
NAME = 0
PLAYBACK_INFO = 1
VOLUME = 2
class QueueAttributeId(utils.OpenIntEnum):
INDEX = 0
COUNT = 1
SHUFFLE_MODE = 2
REPEAT_MODE = 3
class ShuffleMode(utils.OpenIntEnum):
OFF = 0
ONE = 1
ALL = 2
class RepeatMode(utils.OpenIntEnum):
OFF = 0
ONE = 1
ALL = 2
class TrackAttributeId(utils.OpenIntEnum):
ARTIST = 0
ALBUM = 1
TITLE = 2
DURATION = 3
class PlaybackState(utils.OpenIntEnum):
PAUSED = 0
PLAYING = 1
REWINDING = 2
FAST_FORWARDING = 3
@dataclasses.dataclass
class PlaybackInfo:
playback_state: PlaybackState = PlaybackState.PAUSED
playback_rate: float = 1.0
elapsed_time: float = 0.0
# -----------------------------------------------------------------------------
# GATT Server-side
# -----------------------------------------------------------------------------
class Ams(TemplateService):
UUID = GATT_AMS_SERVICE
remote_command_characteristic: Characteristic
entity_update_characteristic: Characteristic
entity_attribute_characteristic: Characteristic
def __init__(self) -> None:
# TODO not the final implementation
self.remote_command_characteristic = Characteristic(
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.Permissions.WRITEABLE,
)
# TODO not the final implementation
self.entity_update_characteristic = Characteristic(
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
Characteristic.Properties.NOTIFY | Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
)
# TODO not the final implementation
self.entity_attribute_characteristic = Characteristic(
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
Characteristic.Properties.READ
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.Permissions.WRITEABLE | Characteristic.Permissions.READABLE,
)
super().__init__(
[
self.remote_command_characteristic,
self.entity_update_characteristic,
self.entity_attribute_characteristic,
]
)
# -----------------------------------------------------------------------------
# GATT Client-side
# -----------------------------------------------------------------------------
class AmsProxy(ProfileServiceProxy):
SERVICE_CLASS = Ams
# NOTE: these don't use adapters, because the format for write and notifications
# are different.
remote_command: CharacteristicProxy[bytes]
entity_update: CharacteristicProxy[bytes]
entity_attribute: CharacteristicProxy[bytes]
def __init__(self, service_proxy: ServiceProxy):
self.remote_command = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC
)
self.entity_update = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC
)
self.entity_attribute = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC
)
class AmsClient(utils.EventEmitter):
EVENT_SUPPORTED_COMMANDS = "supported_commands"
EVENT_PLAYER_NAME = "player_name"
EVENT_PLAYER_PLAYBACK_INFO = "player_playback_info"
EVENT_PLAYER_VOLUME = "player_volume"
EVENT_QUEUE_COUNT = "queue_count"
EVENT_QUEUE_INDEX = "queue_index"
EVENT_QUEUE_SHUFFLE_MODE = "queue_shuffle_mode"
EVENT_QUEUE_REPEAT_MODE = "queue_repeat_mode"
EVENT_TRACK_ARTIST = "track_artist"
EVENT_TRACK_ALBUM = "track_album"
EVENT_TRACK_TITLE = "track_title"
EVENT_TRACK_DURATION = "track_duration"
supported_commands: set[RemoteCommandId]
player_name: str = ""
player_playback_info: PlaybackInfo = PlaybackInfo(PlaybackState.PAUSED, 0.0, 0.0)
player_volume: float = 1.0
queue_count: int = 0
queue_index: int = 0
queue_shuffle_mode: ShuffleMode = ShuffleMode.OFF
queue_repeat_mode: RepeatMode = RepeatMode.OFF
track_artist: str = ""
track_album: str = ""
track_title: str = ""
track_duration: float = 0.0
def __init__(self, ams_proxy: AmsProxy) -> None:
super().__init__()
self._ams_proxy = ams_proxy
self._started = False
self._read_attribute_semaphore = asyncio.Semaphore()
self.supported_commands = set()
@classmethod
async def for_peer(cls, peer: Peer) -> Optional[AmsClient]:
ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy)
if ams_proxy is None:
return None
return cls(ams_proxy)
async def start(self) -> None:
logger.debug("subscribing to remote command characteristic")
await self._ams_proxy.remote_command.subscribe(
self._on_remote_command_notification
)
logger.debug("subscribing to entity update characteristic")
await self._ams_proxy.entity_update.subscribe(
lambda data: utils.AsyncRunner.spawn(
self._on_entity_update_notification(data)
)
)
self._started = True
async def stop(self) -> None:
await self._ams_proxy.remote_command.unsubscribe(
self._on_remote_command_notification
)
await self._ams_proxy.entity_update.unsubscribe(
self._on_entity_update_notification
)
self._started = False
async def observe(
self,
entity: EntityId,
attributes: Iterable[
Union[PlayerAttributeId, QueueAttributeId, TrackAttributeId]
],
) -> None:
await self._ams_proxy.entity_update.write_value(
bytes([entity] + list(attributes)), with_response=True
)
async def command(self, command: RemoteCommandId) -> None:
await self._ams_proxy.remote_command.write_value(
bytes([command]), with_response=True
)
async def play(self) -> None:
await self.command(RemoteCommandId.PLAY)
async def pause(self) -> None:
await self.command(RemoteCommandId.PAUSE)
async def toggle_play_pause(self) -> None:
await self.command(RemoteCommandId.TOGGLE_PLAY_PAUSE)
async def next_track(self) -> None:
await self.command(RemoteCommandId.NEXT_TRACK)
async def previous_track(self) -> None:
await self.command(RemoteCommandId.PREVIOUS_TRACK)
async def volume_up(self) -> None:
await self.command(RemoteCommandId.VOLUME_UP)
async def volume_down(self) -> None:
await self.command(RemoteCommandId.VOLUME_DOWN)
async def advance_repeat_mode(self) -> None:
await self.command(RemoteCommandId.ADVANCE_REPEAT_MODE)
async def advance_shuffle_mode(self) -> None:
await self.command(RemoteCommandId.ADVANCE_SHUFFLE_MODE)
async def skip_forward(self) -> None:
await self.command(RemoteCommandId.SKIP_FORWARD)
async def skip_backward(self) -> None:
await self.command(RemoteCommandId.SKIP_BACKWARD)
async def like_track(self) -> None:
await self.command(RemoteCommandId.LIKE_TRACK)
async def dislike_track(self) -> None:
await self.command(RemoteCommandId.DISLIKE_TRACK)
async def bookmark_track(self) -> None:
await self.command(RemoteCommandId.BOOKMARK_TRACK)
def _on_remote_command_notification(self, data: bytes) -> None:
supported_commands = [RemoteCommandId(command) for command in data]
logger.debug(
f"supported commands: {[command.name for command in supported_commands]}"
)
for command in supported_commands:
self.supported_commands.add(command)
self.emit(self.EVENT_SUPPORTED_COMMANDS)
async def _on_entity_update_notification(self, data: bytes) -> None:
entity = EntityId(data[0])
flags = EntityUpdateFlags(data[2])
value = data[3:]
if flags & EntityUpdateFlags.TRUNCATED:
logger.debug("truncated attribute, fetching full value")
# Write the entity and attribute we're interested in
# (protected by a semaphore, so that we only read one attribute at a time)
async with self._read_attribute_semaphore:
await self._ams_proxy.entity_attribute.write_value(
data[:2], with_response=True
)
value = await self._ams_proxy.entity_attribute.read_value()
if entity == EntityId.PLAYER:
player_attribute = PlayerAttributeId(data[1])
if player_attribute == PlayerAttributeId.NAME:
self.player_name = value.decode()
self.emit(self.EVENT_PLAYER_NAME)
elif player_attribute == PlayerAttributeId.PLAYBACK_INFO:
playback_state_str, playback_rate_str, elapsed_time_str = (
value.decode().split(",")
)
self.player_playback_info = PlaybackInfo(
PlaybackState(int(playback_state_str)),
float(playback_rate_str),
float(elapsed_time_str),
)
self.emit(self.EVENT_PLAYER_PLAYBACK_INFO)
elif player_attribute == PlayerAttributeId.VOLUME:
self.player_volume = float(value.decode())
self.emit(self.EVENT_PLAYER_VOLUME)
else:
logger.warning(f"received unknown player attribute {player_attribute}")
elif entity == EntityId.QUEUE:
queue_attribute = QueueAttributeId(data[1])
if queue_attribute == QueueAttributeId.COUNT:
self.queue_count = int(value)
self.emit(self.EVENT_QUEUE_COUNT)
elif queue_attribute == QueueAttributeId.INDEX:
self.queue_index = int(value)
self.emit(self.EVENT_QUEUE_INDEX)
elif queue_attribute == QueueAttributeId.REPEAT_MODE:
self.queue_repeat_mode = RepeatMode(int(value))
self.emit(self.EVENT_QUEUE_REPEAT_MODE)
elif queue_attribute == QueueAttributeId.SHUFFLE_MODE:
self.queue_shuffle_mode = ShuffleMode(int(value))
self.emit(self.EVENT_QUEUE_SHUFFLE_MODE)
else:
logger.warning(f"received unknown queue attribute {queue_attribute}")
elif entity == EntityId.TRACK:
track_attribute = TrackAttributeId(data[1])
if track_attribute == TrackAttributeId.ARTIST:
self.track_artist = value.decode()
self.emit(self.EVENT_TRACK_ARTIST)
elif track_attribute == TrackAttributeId.ALBUM:
self.track_album = value.decode()
self.emit(self.EVENT_TRACK_ALBUM)
elif track_attribute == TrackAttributeId.TITLE:
self.track_title = value.decode()
self.emit(self.EVENT_TRACK_TITLE)
elif track_attribute == TrackAttributeId.DURATION:
self.track_duration = float(value.decode())
self.emit(self.EVENT_TRACK_DURATION)
else:
logger.warning(f"received unknown track attribute {track_attribute}")
else:
logger.warning(f"received unknown attribute ID {data[1]}")

View File

@@ -452,6 +452,16 @@ class AseStateMachine(gatt.Characteristic):
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
# CIS could be established before enable.
if cis_link := next(
(
cis_link
for cis_link in self.service.device.cis_links.values()
if cis_link.cig_id == self.cig_id and cis_link.cis_id == self.cis_id
),
None,
):
self.on_cis_establishment(cis_link)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)

View File

@@ -946,7 +946,9 @@ class Session:
self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
await self.pairing_config.delegate.display_number(self.passkey, digits=6)
self.connection.cancel_on_disconnection(
self.pairing_config.delegate.display_number(self.passkey, digits=6)
)
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
@@ -1569,11 +1571,12 @@ class Session:
if self.pairing_method == PairingMethod.CTKD_OVER_CLASSIC:
# Authentication is already done in SMP, so remote shall start keys distribution immediately
return
elif self.sc:
if self.sc:
self.send_public_key_command()
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey(self.send_pairing_confirm_command)
@@ -1846,10 +1849,10 @@ class Session:
elif self.pairing_method == PairingMethod.PASSKEY:
self.send_pairing_confirm_command()
else:
# Send our public key back to the initiator
self.send_public_key_command()
def next_steps() -> None:
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (
PairingMethod.JUST_WORKS,

220
examples/run_ams_client.py Normal file
View File

@@ -0,0 +1,220 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import sys
import os
import logging
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport
from bumble.profiles.ams import (
AmsClient,
EntityId,
PlayerAttributeId,
QueueAttributeId,
TrackAttributeId,
RemoteCommandId,
)
# -----------------------------------------------------------------------------
async def handle_command_client(
ams_client: AmsClient, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
while True:
command = (await reader.readline()).decode("utf-8")
if not command.endswith("\n"):
print("command client terminated")
return
command = command.strip()
try:
if command.upper() in [member.name for member in RemoteCommandId]:
await ams_client.command(RemoteCommandId[command.upper()])
continue
except Exception as error:
writer.write(f"ERROR: {error}\n".encode("utf-8"))
writer.write(f"unknown command {command}\n".encode("utf-8"))
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_ams_client.py <device-config> <transport-spec> '
'<bluetooth-address> <mtu>'
)
print('example: run_ams_client.py device1.json usb:0 E1:CA:72:48:C4:E8 512')
return
device_config, transport_spec, bluetooth_address, mtu = sys.argv[1:]
print('<<< connecting to HCI...')
async with await open_transport(transport_spec) as hci_transport:
print('<<< connected')
# Create a device to manage the host, with a custom listener
device = Device.from_config_file_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
await device.power_on()
# Connect to the peer
print(f'=== Connecting to {bluetooth_address}...')
connection = await device.connect(bluetooth_address)
print(f'=== Connected: {connection}')
await connection.encrypt()
peer = Peer(connection)
mtu_int = int(mtu)
if mtu_int:
new_mtu = await peer.request_mtu(mtu_int)
print(f'ATT MTU = {new_mtu}')
ams_client = await AmsClient.for_peer(peer)
if ams_client is None:
print("!!! no AMS service found")
return
# Register event handlers
def on_supported_commands():
print(
color("Supported commands:", "magenta"),
", ".join([command.name for command in ams_client.supported_commands]),
)
ams_client.on(AmsClient.EVENT_SUPPORTED_COMMANDS, on_supported_commands)
def on_player_name():
print(color("Player Name:", "green"), ams_client.player_name)
ams_client.on(AmsClient.EVENT_PLAYER_NAME, on_player_name)
def on_player_playback_info():
print(
color("Playback State:", "green"),
ams_client.player_playback_info.playback_state.name,
)
print(
color("Playback Rate: ", "green"),
ams_client.player_playback_info.playback_rate,
)
print(
color("Elapsed Time: ", "green"),
ams_client.player_playback_info.elapsed_time,
)
ams_client.on(AmsClient.EVENT_PLAYER_PLAYBACK_INFO, on_player_playback_info)
def on_player_volume():
print(color("Volume:", "green"), ams_client.player_volume)
ams_client.on(AmsClient.EVENT_PLAYER_VOLUME, on_player_volume)
def on_queue_count():
print(color("Queue Count:", "yellow"), ams_client.queue_count)
ams_client.on(AmsClient.EVENT_QUEUE_COUNT, on_queue_count)
def on_queue_index():
print(color("Queue Index:", "yellow"), ams_client.queue_index)
ams_client.on(AmsClient.EVENT_QUEUE_INDEX, on_queue_index)
def on_queue_shuffle_mode():
print(
color("Queue Shuffle Mode:", "yellow"),
ams_client.queue_shuffle_mode.name,
)
ams_client.on(AmsClient.EVENT_QUEUE_SHUFFLE_MODE, on_queue_shuffle_mode)
def on_queue_repeat_mode():
print(
color("Queue Repeat Mode:", "yellow"), ams_client.queue_repeat_mode.name
)
ams_client.on(AmsClient.EVENT_QUEUE_REPEAT_MODE, on_queue_repeat_mode)
def on_track_artist():
print(color("Track Artist:", "cyan"), ams_client.track_artist)
ams_client.on(AmsClient.EVENT_TRACK_ARTIST, on_track_artist)
def on_track_album():
print(color("Track Album:", "cyan"), ams_client.track_album)
ams_client.on(AmsClient.EVENT_TRACK_ALBUM, on_track_album)
def on_track_title():
print(color("Track Title:", "cyan"), ams_client.track_title)
ams_client.on(AmsClient.EVENT_TRACK_TITLE, on_track_title)
def on_track_duration():
print(color("Track Duration:", "cyan"), ams_client.track_duration)
ams_client.on(AmsClient.EVENT_TRACK_DURATION, on_track_duration)
# Start the client
await ams_client.start()
# Observe the player, queue and track
await ams_client.observe(
EntityId.PLAYER,
[
PlayerAttributeId.NAME,
PlayerAttributeId.PLAYBACK_INFO,
PlayerAttributeId.VOLUME,
],
)
await ams_client.observe(
EntityId.QUEUE,
[
QueueAttributeId.COUNT,
QueueAttributeId.INDEX,
QueueAttributeId.REPEAT_MODE,
QueueAttributeId.SHUFFLE_MODE,
],
)
await ams_client.observe(
EntityId.TRACK,
[
TrackAttributeId.ALBUM,
TrackAttributeId.ARTIST,
TrackAttributeId.DURATION,
TrackAttributeId.TITLE,
],
)
# Accept a TCP connection to handle commands.
tcp_server = await asyncio.start_server(
lambda reader, writer: handle_command_client(ams_client, reader, writer),
'127.0.0.1',
9000,
)
print("Accepting command client on port 9000")
async with tcp_server:
await tcp_server.serve_forever()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(main())

View File

@@ -7,6 +7,8 @@ name = "bumble"
dynamic = ["version"]
description = "Bluetooth Stack for Apps, Emulation, Test and Experimentation"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE"]
authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.9"
dependencies = [
@@ -42,7 +44,7 @@ test = [
"coverage >= 6.4",
]
development = [
"black == 24.3",
"black ~= 25.1",
"bt-test-interfaces >= 0.0.6",
"grpcio-tools >= 1.62.1",
"invoke >= 1.7.3",

View File

@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import functools
import pytest
import logging
@@ -55,7 +54,7 @@ from bumble.profiles.pacs import (
PublishedAudioCapabilitiesServiceProxy,
)
from bumble.profiles.le_audio import Metadata
from tests.test_utils import TwoDevices
from tests.test_utils import TwoDevices, async_barrier
# -----------------------------------------------------------------------------
@@ -441,15 +440,114 @@ async def test_ascs():
assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE])
assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE])
await asyncio.sleep(0.001)
await async_barrier()
# -----------------------------------------------------------------------------
async def run():
await test_pacs()
@pytest.mark.asyncio
async def test_ascs_enable_source_then_sink():
devices = TwoDevices()
ascs_server = AudioStreamControlService(
device=devices[1], sink_ase_id=[1], source_ase_id=[2]
)
sink_ase = ascs_server.ase_state_machines[1]
source_ase = ascs_server.ase_state_machines[2]
devices[1].add_service(ascs_server)
condition = asyncio.Condition()
async def on_state_change():
async with condition:
condition.notify_all()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
sink_ase.on(sink_ase.EVENT_STATE_CHANGE, on_state_change)
source_ase.on(sink_ase.EVENT_STATE_CHANGE, on_state_change)
await devices.setup_connection()
peer = device.Peer(devices.connections[0])
ascs_client = await peer.discover_service_and_create_proxy(
AudioStreamControlServiceProxy
)
# Config Codec
config = CodecSpecificConfiguration(
sampling_frequency=SamplingFrequency.FREQ_48000,
frame_duration=FrameDuration.DURATION_10000_US,
audio_channel_allocation=AudioLocation.FRONT_LEFT,
octets_per_codec_frame=120,
codec_frames_per_sdu=1,
)
await ascs_client.ase_control_point.write_value(
ASE_Config_Codec(
ase_id=[1, 2],
target_latency=[3, 4],
target_phy=[5, 6],
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
codec_specific_configuration=[config, config],
)
)
async with condition:
await condition.wait_for(
lambda: (
sink_ase.state == AseStateMachine.State.CODEC_CONFIGURED
and source_ase.state == AseStateMachine.State.CODEC_CONFIGURED
)
)
# Config QOS
await ascs_client.ase_control_point.write_value(
ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 1],
cis_id=[1, 1],
sdu_interval=[100, 100],
framing=[0, 0],
phy=[1, 1],
max_sdu=[100, 100],
retransmission_number=[16, 16],
max_transport_latency=[150, 150],
presentation_delay=[10, 10],
)
)
async with condition:
await condition.wait_for(
lambda: (
sink_ase.state == AseStateMachine.State.QOS_CONFIGURED
and source_ase.state == AseStateMachine.State.QOS_CONFIGURED
)
)
# Enable ASE 2
await ascs_client.ase_control_point.write_value(
ASE_Enable(ase_id=[2], metadata=[b'foo'])
)
await async_barrier()
cis_handles = await devices[0].setup_cig(
device.CigParameters(
cig_id=1,
cis_parameters=[device.CigParameters.CisParameters(cis_id=1)],
sdu_interval_c_to_p=100,
sdu_interval_p_to_c=100,
)
)
await devices[0].create_cis([(cis_handles[0], devices.connections[0])])
async with condition:
await condition.wait_for(
lambda: (source_ase.state == AseStateMachine.State.ENABLING)
)
await ascs_client.ase_control_point.write_value(
ASE_Receiver_Start_Ready(ase_id=[2])
)
async with condition:
await condition.wait_for(
lambda: (source_ase.state == AseStateMachine.State.STREAMING)
)
# Enable ASE 1
await ascs_client.ase_control_point.write_value(
ASE_Enable(ase_id=[1], metadata=[b'bar'])
)
async with condition:
await condition.wait_for(
lambda: (sink_ase.state == AseStateMachine.State.STREAMING)
)

View File

@@ -611,6 +611,37 @@ async def test_enter_and_exit_sniff_mode():
assert devices.connections[0].classic_interval == 2
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_le_request_subrate():
devices = TwoDevices()
await devices.setup_connection()
q = asyncio.Queue()
def on_le_subrate_change():
q.put_nowait(lambda: None)
devices.connections[0].on(Connection.EVENT_LE_SUBRATE_CHANGE, on_le_subrate_change)
await devices[0].send_command(
hci.HCI_LE_Subrate_Request_Command(
connection_handle=devices.connections[0].handle,
subrate_min=2,
subrate_max=2,
max_latency=2,
continuation_number=1,
supervision_timeout=2,
)
)
await asyncio.wait_for(q.get(), _TIMEOUT)
assert devices.connections[0].parameters.subrate_factor == 2
assert devices.connections[0].parameters.peripheral_latency == 2
assert devices.connections[0].parameters.continuation_number == 1
assert devices.connections[0].parameters.supervision_timeout == 20
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():

View File

@@ -237,7 +237,7 @@ async def test_hf_indicator(hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtoco
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_codec_negotiation(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
@@ -281,7 +281,7 @@ async def test_answer(hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]):
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reject_incoming_call(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
@@ -307,7 +307,7 @@ async def test_terminate_call(hfp_connections: tuple[hfp.HfProtocol, hfp.AgProto
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_query_calls_without_calls(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
@@ -317,7 +317,7 @@ async def test_query_calls_without_calls(
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_query_calls_with_calls(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
ag.calls.append(
@@ -418,7 +418,7 @@ async def test_speaker_volume(hfp_connections: tuple[hfp.HfProtocol, hfp.AgProto
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_microphone_volume(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
microphone_volume_future = asyncio.get_running_loop().create_future()
@@ -448,7 +448,7 @@ async def test_cli_notification(hfp_connections: tuple[hfp.HfProtocol, hfp.AgPro
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_voice_recognition_from_hf(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
voice_recognition_future = asyncio.get_running_loop().create_future()
@@ -462,7 +462,7 @@ async def test_voice_recognition_from_hf(
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_voice_recognition_from_ag(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
voice_recognition_future = asyncio.get_running_loop().create_future()
@@ -572,7 +572,7 @@ async def test_sco_setup():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_hf_batched_response(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections
@@ -584,7 +584,7 @@ async def test_hf_batched_response(
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ag_batched_commands(
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol]
hfp_connections: tuple[hfp.HfProtocol, hfp.AgProtocol],
):
hf, ag = hfp_connections

View File

@@ -22,12 +22,8 @@ import random
import pytest
from bumble.core import ProtocolError
from bumble.l2cap import (
L2CAP_Connection_Request,
ClassicChannelSpec,
LeCreditBasedChannelSpec,
)
from .test_utils import TwoDevices
from bumble import l2cap
from .test_utils import TwoDevices, async_barrier
# -----------------------------------------------------------------------------
@@ -41,42 +37,53 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def test_helpers():
psm = L2CAP_Connection_Request.serialize_psm(0x01)
psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x01)
assert psm == bytes([0x01, 0x00])
psm = L2CAP_Connection_Request.serialize_psm(0x1023)
psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x1023)
assert psm == bytes([0x23, 0x10])
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x242311)
assert psm == bytes([0x11, 0x23, 0x24])
(offset, psm) = L2CAP_Connection_Request.parse_psm(
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x01, 0x00, 0x44]), 1
)
assert offset == 3
assert psm == 0x01
(offset, psm) = L2CAP_Connection_Request.parse_psm(
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x23, 0x10, 0x44]), 1
)
assert offset == 3
assert psm == 0x1023
(offset, psm) = L2CAP_Connection_Request.parse_psm(
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
)
assert offset == 4
assert psm == 0x242311
rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88)
rq = l2cap.L2CAP_Connection_Request(psm=0x01, source_cid=0x44, identifier=0x88)
brq = bytes(rq)
srq = L2CAP_Connection_Request.from_bytes(brq)
assert isinstance(srq, L2CAP_Connection_Request)
srq = l2cap.L2CAP_Connection_Request.from_bytes(brq)
assert isinstance(srq, l2cap.L2CAP_Connection_Request)
assert srq.psm == rq.psm
assert srq.source_cid == rq.source_cid
assert srq.identifier == rq.identifier
# -----------------------------------------------------------------------------
def test_unimplemented_control_frame():
frame = l2cap.L2CAP_Control_Frame(identifier=1)
frame.code = 0xFF
frame.payload = b'123456'
parsed = l2cap.L2CAP_Control_Frame.from_bytes(bytes(frame))
assert parsed.code == 0xFF
assert parsed.payload == b'123456'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
@@ -87,7 +94,7 @@ async def test_basic_connection():
# Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError):
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(psm)
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
# Now add a listener
@@ -104,10 +111,10 @@ async def test_basic_connection():
channel.sink = on_data
devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(psm=1234), handler=on_coc
spec=l2cap.LeCreditBasedChannelSpec(psm=1234), handler=on_coc
)
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(psm)
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
@@ -137,6 +144,41 @@ async def test_basic_connection():
assert sent_bytes == received_bytes
# -----------------------------------------------------------------------------
@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType))
async def test_l2cap_information_request(monkeypatch, info_type):
# TODO: Replace handlers with API when implemented
devices = await TwoDevices.create_with_connection()
# Register handlers
info_rsp = list[l2cap.L2CAP_Information_Response]()
def on_l2cap_information_response(connection, cid, frame):
info_rsp.append(frame)
assert (connection := devices.connections[0])
channel_manager = devices[0].l2cap_channel_manager
monkeypatch.setattr(
channel_manager,
'on_l2cap_information_response',
on_l2cap_information_response,
raising=False,
)
channel_manager.send_control_frame(
connection,
l2cap.L2CAP_LE_SIGNALING_CID,
l2cap.L2CAP_Information_Request(
identifier=channel_manager.next_identifier(connection),
info_type=info_type,
),
)
await async_barrier()
response = info_rsp[0]
assert response.result == l2cap.L2CAP_Information_Response.Result.SUCCESS
# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
devices = TwoDevices()
@@ -151,11 +193,11 @@ async def transfer_payload(max_credits, mtu, mps):
channel.sink = on_data
server = devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=on_coc,
)
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(server.psm)
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
@@ -205,10 +247,10 @@ async def test_bidirectional_transfer():
client_received.append(data)
server = devices.devices[1].create_l2cap_server(
spec=LeCreditBasedChannelSpec(), handler=on_server_coc
spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=LeCreditBasedChannelSpec(server.psm)
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
client_channel.sink = on_client_data
@@ -242,10 +284,10 @@ async def test_mtu():
channel.on('open', lambda: on_channel_open(channel))
server = devices.devices[1].create_l2cap_server(
spec=ClassicChannelSpec(mtu=345), handler=on_channel
spec=l2cap.ClassicChannelSpec(mtu=345), handler=on_channel
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=ClassicChannelSpec(server.psm, mtu=456)
spec=l2cap.ClassicChannelSpec(server.psm, mtu=456)
)
assert client_channel.peer_mtu == 345

View File

@@ -322,7 +322,9 @@ async def test_self_smp(io_caps, sc, mitm, key_dist):
return 0
else:
if (
self.peer_delegate.io_capability
self.io_capability
== PairingDelegate.IoCapability.KEYBOARD_INPUT_ONLY
and self.peer_delegate.io_capability
== PairingDelegate.IoCapability.KEYBOARD_INPUT_ONLY
):
peer_number = 6789