Compare commits

..

23 Commits

Author SHA1 Message Date
khsiao-google ca23d6b89a Revert "Improve connection related functions and names" 2025-09-10 15:00:41 +08:00
khsiao-google d86d69d816 Merge pull request #771 from khsiao-google/update
Improve connection related functions and names
2025-09-10 14:56:38 +08:00
khsiao-google dc93f32a9a Replace core.ConnectionParameters by Connection.Parameters in device.py 2025-09-08 02:00:49 +00:00
zxzxwu 9838908a26 Merge pull request #772 from zxzxwu/hap
HAP: Slightly Pythonic refactor
2025-09-05 23:08:09 +08:00
Josh Wu 613519f0b3 HAP: Slightly Pythonic refactor
* Add missing type annotations
* Avoid __value__ and _ arguments (this will be a problem for override).
* Replace while-pop with for loop
2025-09-05 21:02:16 +08:00
zxzxwu a943ea57ef Merge pull request #770 from zxzxwu/avrcp
AVRCP: Implement most commands and responses
2025-09-04 16:18:54 +08:00
Josh Wu 14401910bb AVRCP: Implement most commands and responses 2025-09-03 13:20:10 +08:00
khsiao-google 5d35ed471c Merge pull request #769 from khsiao-google/update
Add typing for host.py
2025-09-02 14:59:27 +08:00
khsiao-google c720ad5fdc Add typing for host.py 2025-09-02 06:01:39 +00:00
khsiao-google f02183f95d Merge pull request #764 from khsiao-google/update
Add typing for device.py
2025-09-01 15:19:57 +08:00
khsiao-google d903937a51 Merge branch 'main' into update 2025-09-01 07:14:19 +00:00
zxzxwu 6381ee0ab1 Merge pull request #767 from zxzxwu/avrcp
Migrate AVRCP packets to dataclasses
2025-09-01 13:26:56 +08:00
Gilles Boccon-Gibod 59d99780e1 Merge pull request #768 from google/gbg/data-types
add support for data type classes
2025-08-30 13:04:32 -07:00
Gilles Boccon-Gibod 4bf0bc03af more python compat 2025-08-30 12:13:34 -07:00
Gilles Boccon-Gibod 91ba2f61f1 python 3.9 and 3.10 compatibility 2025-08-30 12:07:08 -07:00
Gilles Boccon-Gibod 116dc9b319 add support for data type classes 2025-08-29 13:17:17 -07:00
Josh Wu 9f3d8c9b49 Migrate AVRCP responses to dataclasses 2025-08-28 21:42:38 +08:00
Josh Wu 31961febe5 Migrate AVRCP events to dataclasses 2025-08-28 17:00:20 +08:00
Josh Wu dab0993cba Migrate AVRCP packets to dataclasses 2025-08-28 17:00:20 +08:00
zxzxwu 6f73b736d7 Merge pull request #766 from zxzxwu/l2cap
Remove depreacated L2CAP APIs
2025-08-28 10:58:35 +08:00
Josh Wu 6091e6365d Remove depreacated L2CAP APIs 2025-08-27 14:15:08 +08:00
khsiao-google 3333ba472b Add typing for device.py 2025-08-26 09:22:06 +00:00
Gilles Boccon-Gibod 8bda7d2212 Merge pull request #763 from google/gbg/isort 2025-08-22 13:50:27 -07:00
31 changed files with 3582 additions and 1429 deletions
+4 -1
View File
@@ -104,5 +104,8 @@
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python-envs.defaultEnvManager": "ms-python.python:system",
"python-envs.pythonProjects": []
"python-envs.pythonProjects": [],
"nrf-connect.applications": [
"${workspaceFolder}/extras/zephyr/hci_usb"
]
}
+8 -21
View File
@@ -39,7 +39,7 @@ import bumble.device
import bumble.logging
import bumble.transport
import bumble.utils
from bumble import company_ids, core, gatt, hci
from bumble import company_ids, core, data_types, gatt, hci
from bumble.audio import io as audio_io
from bumble.colors import color
from bumble.profiles import bap, bass, le_audio, pbp
@@ -859,21 +859,13 @@ async def run_transmit(
)
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
advertising_manufacturer_data = (
b''
if manufacturer_data is None
else bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
struct.pack('<H', manufacturer_data[0])
+ manufacturer_data[1],
)
]
)
advertising_data_types: list[core.DataType] = [
data_types.BroadcastName(broadcast_name)
]
if manufacturer_data is not None:
advertising_data_types.append(
data_types.ManufacturerSpecificData(*manufacturer_data)
)
)
advertising_set = await device.create_advertising_set(
advertising_parameters=bumble.device.AdvertisingParameters(
@@ -885,12 +877,7 @@ async def run_transmit(
),
advertising_data=(
broadcast_audio_announcement.get_advertising_data()
+ bytes(
core.AdvertisingData(
[(core.AdvertisingData.BROADCAST_NAME, broadcast_name.encode())]
)
)
+ advertising_manufacturer_data
+ bytes(core.AdvertisingData(advertising_data_types))
),
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
periodic_advertising_interval_min=80,
+7 -16
View File
@@ -37,7 +37,7 @@ import click
import bumble
import bumble.logging
from bumble import utils
from bumble import data_types, utils
from bumble.colors import color
from bumble.core import AdvertisingData
from bumble.device import AdvertisingParameters, CisLink, Device, DeviceConfiguration
@@ -330,22 +330,13 @@ class Speaker:
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
data_types.CompleteLocalName(device_config.name),
data_types.Flags(
AdvertisingData.Flags.LE_GENERAL_DISCOVERABLE_MODE
| AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_NOT_SUPPORTED_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
data_types.IncompleteListOf16BitServiceUUIDs(
[pacs.PublishedAudioCapabilitiesService.UUID]
),
]
)
+14 -27
View File
@@ -23,6 +23,7 @@ import struct
import click
from prompt_toolkit.shortcuts import PromptSession
from bumble import data_types
from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
@@ -34,6 +35,7 @@ from bumble.core import (
UUID,
AdvertisingData,
Appearance,
DataType,
PhysicalTransport,
ProtocolError,
)
@@ -506,33 +508,21 @@ async def pair(
if mode == 'dual':
flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE
ad_structs = [
(
AdvertisingData.FLAGS,
bytes([flags]),
),
(AdvertisingData.COMPLETE_LOCAL_NAME, 'Bumble'.encode()),
advertising_data_types: list[DataType] = [
data_types.Flags(flags),
data_types.CompleteLocalName('Bumble'),
]
if service_uuids_16:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_16),
)
advertising_data_types.append(
data_types.IncompleteListOf16BitServiceUUIDs(service_uuids_16)
)
if service_uuids_32:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_32),
)
advertising_data_types.append(
data_types.IncompleteListOf32BitServiceUUIDs(service_uuids_32)
)
if service_uuids_128:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_128),
)
advertising_data_types.append(
data_types.IncompleteListOf128BitServiceUUIDs(service_uuids_128)
)
if advertise_appearance:
@@ -559,13 +549,10 @@ async def pair(
advertise_appearance_int = int(
Appearance(category_enum, subcategory_enum)
)
ad_structs.append(
(
AdvertisingData.APPEARANCE,
struct.pack('<H', advertise_appearance_int),
)
advertising_data_types.append(
data_types.Appearance(category_enum, subcategory_enum)
)
device.advertising_data = bytes(AdvertisingData(ad_structs))
device.advertising_data = bytes(AdvertisingData(advertising_data_types))
await device.start_advertising(
auto_restart=True,
own_address_type=(
+11 -1
View File
@@ -20,6 +20,7 @@ import asyncio
import click
import bumble.logging
from bumble import data_types
from bumble.colors import color
from bumble.device import Advertisement, Device
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
@@ -94,13 +95,22 @@ class AdvertisementPrinter:
else:
phy_info = ''
details = separator.join(
[
data_type.to_string(use_label=True)
for data_type in data_types.data_types_from_advertising_data(
advertisement.data
)
]
)
print(
f'>>> {color(address, address_color)} '
f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'
f'{details}\n'
)
def on_advertisement(self, advertisement):
+1126 -613
View File
File diff suppressed because it is too large Load Diff
+573 -237
View File
File diff suppressed because it is too large Load Diff
+1025
View File
File diff suppressed because it is too large Load Diff
+124 -116
View File
@@ -45,7 +45,19 @@ from typing import (
from typing_extensions import Self
from bumble import core, gatt_client, gatt_server, hci, l2cap, pairing, sdp, smp, utils
from bumble import (
core,
data_types,
gatt,
gatt_client,
gatt_server,
hci,
l2cap,
pairing,
sdp,
smp,
utils,
)
from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from bumble.colors import color
from bumble.core import (
@@ -253,7 +265,7 @@ class ExtendedAdvertisement(Advertisement):
# -----------------------------------------------------------------------------
class AdvertisementDataAccumulator:
def __init__(self, passive=False):
def __init__(self, passive: bool = False):
self.passive = passive
self.last_advertisement = None
self.last_data = b''
@@ -1232,7 +1244,7 @@ class LePhyOptions:
PREFER_S_2_CODED_PHY = 1
PREFER_S_8_CODED_PHY = 2
def __init__(self, coded_phy_preference=0):
def __init__(self, coded_phy_preference: int = 0):
self.coded_phy_preference = coded_phy_preference
def __int__(self):
@@ -1680,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter):
self_address: hci.Address
self_resolvable_address: Optional[hci.Address]
peer_address: hci.Address
peer_name: Optional[str]
peer_resolvable_address: Optional[hci.Address]
peer_le_features: Optional[hci.LeFeatureMask]
role: hci.Role
@@ -1869,16 +1882,6 @@ class Connection(utils.CompositeEventEmitter):
def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(self.handle, cid, pdu)
@utils.deprecated("Please use create_l2cap_channel()")
async def open_l2cap_channel(
self,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS,
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
@overload
async def create_l2cap_channel(
self, spec: l2cap.ClassicChannelSpec
@@ -1930,7 +1933,7 @@ class Connection(utils.CompositeEventEmitter):
self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result)
self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception)
async def set_data_length(self, tx_octets, tx_time) -> None:
async def set_data_length(self, tx_octets: int, tx_time: int) -> None:
return await self.device.set_data_length(self, tx_octets, tx_time)
async def update_parameters(
@@ -1960,7 +1963,12 @@ class Connection(utils.CompositeEventEmitter):
use_l2cap=use_l2cap,
)
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
async def set_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options)
async def get_phy(self) -> ConnectionPHY:
@@ -2059,9 +2067,7 @@ class DeviceConfiguration:
connectable: bool = True
discoverable: bool = True
advertising_data: bytes = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))]
)
AdvertisingData([data_types.CompleteLocalName(DEVICE_DEFAULT_NAME)])
)
irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None
@@ -2105,9 +2111,7 @@ class DeviceConfiguration:
self.advertising_data = bytes.fromhex(advertising_data)
elif name is not None:
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
AdvertisingData([data_types.CompleteLocalName(self.name)])
)
# Load scan response data
@@ -2159,7 +2163,7 @@ class DeviceConfiguration:
# Decorator that converts the first argument from a connection handle to a connection
def with_connection_from_handle(function):
@functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs):
def wrapper(self, connection_handle: int, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None:
raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
@@ -2172,7 +2176,7 @@ def with_connection_from_handle(function):
# Decorator that converts the first argument from a bluetooth address to a connection
def with_connection_from_address(function):
@functools.wraps(function)
def wrapper(self, address, *args, **kwargs):
def wrapper(self, address: hci.Address, *args, **kwargs):
if connection := self.pending_connections.get(address, False):
return function(self, connection, *args, **kwargs)
for connection in self.connections.values():
@@ -2591,36 +2595,6 @@ class Device(utils.CompositeEventEmitter):
None,
)
@utils.deprecated("Please use create_l2cap_server()")
def register_l2cap_server(self, psm, server) -> int:
return self.l2cap_channel_manager.register_server(psm, server)
@utils.deprecated("Please use create_l2cap_server()")
def register_l2cap_channel_server(
self,
psm,
server,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS,
):
return self.l2cap_channel_manager.register_le_coc_server(
psm, server, max_credits, mtu, mps
)
@utils.deprecated("Please use create_l2cap_channel()")
async def open_l2cap_channel(
self,
connection,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS,
):
return await self.l2cap_channel_manager.open_le_coc(
connection, psm, max_credits, mtu, mps
)
@overload
async def create_l2cap_channel(
self,
@@ -2688,7 +2662,7 @@ class Device(utils.CompositeEventEmitter):
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
async def send_command(self, command, check_result=False):
async def send_command(self, command: hci.HCI_Command, check_result: bool = False):
try:
return await asyncio.wait_for(
self.host.send_command(command, check_result), self.command_timeout
@@ -2945,13 +2919,13 @@ class Device(utils.CompositeEventEmitter):
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
return self.host.supports_le_features(feature)
def supports_le_phy(self, phy: int) -> bool:
if phy == hci.HCI_LE_1M_PHY:
def supports_le_phy(self, phy: hci.Phy) -> bool:
if phy == hci.Phy.LE_1M:
return True
feature_map: dict[int, hci.LeFeatureMask] = {
hci.HCI_LE_2M_PHY: hci.LeFeatureMask.LE_2M_PHY,
hci.HCI_LE_CODED_PHY: hci.LeFeatureMask.LE_CODED_PHY,
feature_map: dict[hci.Phy, hci.LeFeatureMask] = {
hci.Phy.LE_2M: hci.LeFeatureMask.LE_2M_PHY,
hci.Phy.LE_CODED: hci.LeFeatureMask.LE_CODED_PHY,
}
if phy not in feature_map:
raise InvalidArgumentError('invalid PHY')
@@ -3555,7 +3529,9 @@ class Device(utils.CompositeEventEmitter):
self.discovering = False
@host_event_handler
def on_inquiry_result(self, address, class_of_device, data, rssi):
def on_inquiry_result(
self, address: hci.Address, class_of_device: int, data: bytes, rssi: int
):
self.emit(
self.EVENT_INQUIRY_RESULT,
address,
@@ -3564,7 +3540,9 @@ class Device(utils.CompositeEventEmitter):
rssi,
)
async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled):
async def set_scan_enable(
self, inquiry_scan_enabled: bool, page_scan_enabled: bool
):
if inquiry_scan_enabled and page_scan_enabled:
scan_enable = 0x03
elif page_scan_enabled:
@@ -3584,14 +3562,7 @@ class Device(utils.CompositeEventEmitter):
# Synthesize an inquiry response if none is set already
if self.inquiry_response is None:
self.inquiry_response = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(self.name, 'utf-8'),
)
]
)
AdvertisingData([data_types.CompleteLocalName(self.name)])
)
# Update the controller
@@ -3697,6 +3668,7 @@ class Device(utils.CompositeEventEmitter):
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, transport
) # TODO: timeout
@@ -3724,7 +3696,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if transport == PhysicalTransport.LE or (
# match BR/EDR connection failure event against peer address
error.transport == transport
@@ -3944,6 +3916,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -4002,7 +3975,7 @@ class Device(utils.CompositeEventEmitter):
):
pending_connection.set_result(connection)
def on_connection_failure(error):
def on_connection_failure(error: core.ConnectionError):
if (
error.transport == PhysicalTransport.BR_EDR
and error.peer_address == peer_address
@@ -4039,7 +4012,7 @@ class Device(utils.CompositeEventEmitter):
self.pending_connections.pop(peer_address, None)
@asynccontextmanager
async def connect_as_gatt(self, peer_address):
async def connect_as_gatt(self, peer_address: Union[hci.Address, str]):
async with AsyncExitStack() as stack:
connection = await stack.enter_async_context(
await self.connect(peer_address)
@@ -4075,6 +4048,7 @@ class Device(utils.CompositeEventEmitter):
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
@@ -4120,7 +4094,9 @@ class Device(utils.CompositeEventEmitter):
)
self.disconnecting = False
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
async def set_data_length(
self, connection: Connection, tx_octets: int, tx_time: int
) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB:
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
@@ -4223,7 +4199,11 @@ class Device(utils.CompositeEventEmitter):
)
async def set_connection_phy(
self, connection, tx_phys=None, rx_phys=None, phy_options=None
self,
connection: Connection,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
phy_options: int = 0,
):
if not self.host.supports_command(hci.HCI_LE_SET_PHY_COMMAND):
logger.warning('ignoring request, command not supported')
@@ -4239,7 +4219,7 @@ class Device(utils.CompositeEventEmitter):
all_phys=all_phys_bits,
tx_phys=hci.phy_list_to_bits(tx_phys),
rx_phys=hci.phy_list_to_bits(rx_phys),
phy_options=0 if phy_options is None else int(phy_options),
phy_options=phy_options,
)
)
@@ -4250,7 +4230,11 @@ class Device(utils.CompositeEventEmitter):
)
raise hci.HCI_StatusError(result)
async def set_default_phy(self, tx_phys=None, rx_phys=None):
async def set_default_phy(
self,
tx_phys: Optional[Iterable[hci.Phy]] = None,
rx_phys: Optional[Iterable[hci.Phy]] = None,
):
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
)
@@ -4288,7 +4272,7 @@ class Device(utils.CompositeEventEmitter):
check_result=True,
)
async def find_peer_by_name(self, name, transport=PhysicalTransport.LE):
async def find_peer_by_name(self, name: str, transport=PhysicalTransport.LE):
"""
Scan for a peer with a given name and return its address.
"""
@@ -4303,7 +4287,7 @@ class Device(utils.CompositeEventEmitter):
if local_name == name:
peer_address.set_result(address)
listener = None
listener: Optional[Callable[..., None]] = None
was_scanning = self.scanning
was_discovering = self.discovering
try:
@@ -4409,10 +4393,10 @@ class Device(utils.CompositeEventEmitter):
def smp_session_proxy(self, session_proxy: type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy
async def pair(self, connection):
async def pair(self, connection: Connection):
return await self.smp_manager.pair(connection)
def request_pairing(self, connection):
def request_pairing(self, connection: Connection):
return self.smp_manager.request_pairing(connection)
async def get_long_term_key(
@@ -4500,7 +4484,7 @@ class Device(utils.CompositeEventEmitter):
on_authentication_failure,
)
async def encrypt(self, connection, enable=True):
async def encrypt(self, connection: Connection, enable: bool = True):
if not enable and connection.transport == PhysicalTransport.LE:
raise InvalidArgumentError('`enable` parameter is classic only.')
@@ -4510,7 +4494,7 @@ class Device(utils.CompositeEventEmitter):
def on_encryption_change():
pending_encryption.set_result(None)
def on_encryption_failure(error_code):
def on_encryption_failure(error_code: int):
pending_encryption.set_exception(hci.HCI_Error(error_code))
connection.on(
@@ -4602,10 +4586,10 @@ class Device(utils.CompositeEventEmitter):
async def switch_role(self, connection: Connection, role: hci.Role):
pending_role_change = asyncio.get_running_loop().create_future()
def on_role_change(new_role):
def on_role_change(new_role: hci.Role):
pending_role_change.set_result(new_role)
def on_role_change_failure(error_code):
def on_role_change_failure(error_code: int):
pending_role_change.set_exception(hci.HCI_Error(error_code))
connection.on(connection.EVENT_ROLE_CHANGE, on_role_change)
@@ -5195,10 +5179,10 @@ class Device(utils.CompositeEventEmitter):
):
connection.emit(connection.EVENT_LINK_KEY)
def add_service(self, service):
def add_service(self, service: gatt.Service):
self.gatt_server.add_service(service)
def add_services(self, services):
def add_services(self, services: Iterable[gatt.Service]):
self.gatt_server.add_services(services)
def add_default_services(
@@ -5294,10 +5278,10 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
def on_advertising_set_termination(
self,
status,
advertising_handle,
connection_handle,
number_of_completed_extended_advertising_events,
status: int,
advertising_handle: int,
connection_handle: int,
number_of_completed_extended_advertising_events: int,
):
# Legacy advertising set is also one of extended advertising sets.
if not (
@@ -5596,7 +5580,12 @@ class Device(utils.CompositeEventEmitter):
)
@host_event_handler
def on_connection_failure(self, transport, peer_address, error_code):
def on_connection_failure(
self,
transport: hci.PhysicalTransport,
peer_address: hci.Address,
error_code: int,
):
logger.debug(
f'*** Connection failed: {hci.HCI_Constant.error_name(error_code)}'
)
@@ -5715,7 +5704,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication(self, connection):
def on_connection_authentication(self, connection: Connection):
logger.debug(
f'*** Connection Authentication: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -5725,7 +5714,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_authentication_failure(self, connection, error):
def on_connection_authentication_failure(
self, connection: Connection, error: core.ConnectionError
):
logger.debug(
f'*** Connection Authentication Failure: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, error={error}'
@@ -5767,10 +5758,13 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_address
def on_authentication_io_capability_response(
self, connection, io_capability, authentication_requirements
self,
connection: Connection,
io_capability: int,
authentication_requirements: int,
):
connection.peer_pairing_io_capability = io_capability
connection.peer_pairing_authentication_requirements = (
connection.pairing_peer_io_capability = io_capability
connection.pairing_peer_authentication_requirements = (
authentication_requirements
)
@@ -5781,7 +5775,7 @@ class Device(utils.CompositeEventEmitter):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
peer_io_capability = connection.peer_pairing_io_capability
peer_io_capability = connection.pairing_peer_io_capability
async def confirm() -> bool:
# Ask the user to confirm the pairing, without display
@@ -5856,7 +5850,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_request(self, connection) -> None:
def on_authentication_user_passkey_request(self, connection: Connection) -> None:
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5899,7 +5893,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_pin_code_request(self, connection):
def on_pin_code_request(self, connection: Connection):
# Classic legacy pairing
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5943,7 +5937,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_notification(self, connection, passkey):
def on_authentication_user_passkey_notification(
self, connection: Connection, passkey: int
):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -5955,14 +5951,15 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name(self, connection: Connection, address, remote_name):
def on_remote_name(
self, connection: Connection, address: hci.Address, remote_name: bytes
):
# Try to decode the name
try:
remote_name = remote_name.decode('utf-8')
if connection:
connection.peer_name = remote_name
connection.peer_name = remote_name.decode('utf-8')
connection.emit(connection.EVENT_REMOTE_NAME)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name)
self.emit(self.EVENT_REMOTE_NAME, address, remote_name.decode('utf-8'))
except UnicodeDecodeError as error:
logger.warning('peer name is not valid UTF-8')
if connection:
@@ -5973,7 +5970,9 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_remote_name_failure(self, connection: Connection, address, error):
def on_remote_name_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
@@ -6174,7 +6173,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_encryption_key_refresh(self, connection):
def on_connection_encryption_key_refresh(self, connection: Connection):
logger.debug(
f'*** Connection Key Refresh: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}'
@@ -6212,7 +6211,9 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_parameters_update_failure(self, connection, error):
def on_connection_parameters_update_failure(
self, connection: Connection, error: int
):
logger.debug(
f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6222,7 +6223,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update(self, connection, phy):
def on_connection_phy_update(self, connection: Connection, phy: core.ConnectionPHY):
logger.debug(
f'*** Connection PHY Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6232,7 +6233,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_phy_update_failure(self, connection, error):
def on_connection_phy_update_failure(self, connection: Connection, error: int):
logger.debug(
f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6261,7 +6262,7 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection, att_mtu):
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
@@ -6273,7 +6274,12 @@ class Device(utils.CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_connection_data_length_change(
self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time
self,
connection: Connection,
max_tx_octets: int,
max_tx_time: int,
max_rx_octets: int,
max_rx_time: int,
):
logger.debug(
f'*** Connection Data Length Change: [0x{connection.handle:04X}] '
@@ -6398,14 +6404,16 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_role_change(self, connection, new_role):
def on_role_change(self, connection: Connection, new_role: hci.Role):
connection.role = new_role
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
# [Classic only]
@host_event_handler
@try_with_connection_from_address
def on_role_change_failure(self, connection, address, error):
def on_role_change_failure(
self, connection: Connection, address: hci.Address, error: int
):
if connection:
connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error)
self.emit(self.EVENT_ROLE_CHANGE_FAILURE, address, error)
@@ -6419,7 +6427,7 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status) -> None:
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
def on_pairing_start(self, connection: Connection) -> None:
@@ -6443,7 +6451,7 @@ class Device(utils.CompositeEventEmitter):
connection.emit(connection.EVENT_PAIRING_FAILURE, reason)
@with_connection_from_handle
def on_gatt_pdu(self, connection, pdu):
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu)
@@ -6465,7 +6473,7 @@ class Device(utils.CompositeEventEmitter):
connection.gatt_server.on_gatt_pdu(connection, att_pdu)
@with_connection_from_handle
def on_smp_pdu(self, connection, pdu):
def on_smp_pdu(self, connection: Connection, pdu: bytes):
self.smp_manager.on_smp_pdu(connection, pdu)
@host_event_handler
+62 -13
View File
@@ -26,7 +26,17 @@ import secrets
import struct
from collections.abc import Sequence
from dataclasses import field
from typing import Any, Callable, ClassVar, Iterable, Optional, TypeVar, Union, cast
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import Self
@@ -111,23 +121,57 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
class SpecableEnum(utils.OpenIntEnum):
@classmethod
def type_spec(cls, size: int):
return {'size': size, 'mapper': lambda x: cls(x).name}
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
'serializer': lambda x: x.to_bytes(size, byteorder),
'parser': lambda data, offset: (
offset + size,
cls(int.from_bytes(data[offset : offset + size], byteorder)),
),
'mapper': lambda x: cls(x).name,
}
@classmethod
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
def type_metadata(
cls,
size: int,
list_begin: bool = False,
list_end: bool = False,
byteorder: Literal['little', 'big'] = 'little',
):
return metadata(
cls.type_spec(size, byteorder),
list_begin=list_begin,
list_end=list_end,
)
class SpecableFlag(enum.IntFlag):
@classmethod
def type_spec(cls, size: int):
return {'size': size, 'mapper': lambda x: cls(x).name}
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
'serializer': lambda x: x.to_bytes(size, byteorder),
'parser': lambda data, offset: (
offset + size,
cls(int.from_bytes(data[offset : offset + size], byteorder)),
),
'mapper': lambda x: cls(x).name,
}
@classmethod
def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False):
return metadata(cls.type_spec(size), list_begin=list_begin, list_end=list_end)
def type_metadata(
cls,
size: int,
list_begin: bool = False,
list_end: bool = False,
byteorder: Literal['little', 'big'] = 'little',
):
return metadata(
cls.type_spec(size, byteorder),
list_begin=list_begin,
list_end=list_end,
)
# -----------------------------------------------------------------------------
@@ -1322,7 +1366,7 @@ class LeFeature(SpecableEnum):
MONITORING_ADVERTISERS = 64
FRAME_SPACE_UPDATE = 65
class LeFeatureMask(enum.IntFlag):
class LeFeatureMask(utils.CompatibleIntFlag):
LE_ENCRYPTION = 1 << LeFeature.LE_ENCRYPTION
CONNECTION_PARAMETERS_REQUEST_PROCEDURE = 1 << LeFeature.CONNECTION_PARAMETERS_REQUEST_PROCEDURE
EXTENDED_REJECT_INDICATION = 1 << LeFeature.EXTENDED_REJECT_INDICATION
@@ -1463,7 +1507,7 @@ class LmpFeature(SpecableEnum):
SLOT_AVAILABILITY_MASK = 138
TRAIN_NUDGING = 139
class LmpFeatureMask(enum.IntFlag):
class LmpFeatureMask(utils.CompatibleIntFlag):
# Page 0 (Legacy LMP features)
LMP_3_SLOT_PACKETS = (1 << LmpFeature.LMP_3_SLOT_PACKETS)
LMP_5_SLOT_PACKETS = (1 << LmpFeature.LMP_5_SLOT_PACKETS)
@@ -2135,6 +2179,7 @@ class Address:
if len(address) == 12 + 5:
# Form with ':' separators
address = address.replace(':', '')
self.address_bytes = bytes(reversed(bytes.fromhex(address)))
if len(self.address_bytes) != 6:
@@ -6421,7 +6466,9 @@ class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event):
irc: int = field(metadata=metadata(1))
max_pdu: int = field(metadata=metadata(2))
iso_interval: int = field(metadata=metadata(2))
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
connection_handle: Sequence[int] = field(
metadata=metadata(2, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------
@@ -6453,7 +6500,9 @@ class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event):
irc: int = field(metadata=metadata(1))
max_pdu: int = field(metadata=metadata(2))
iso_interval: int = field(metadata=metadata(2))
connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True))
connection_handle: Sequence[int] = field(
metadata=metadata(2, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------
+16 -8
View File
@@ -217,33 +217,41 @@ class HID(ABC, utils.EventEmitter):
self.role = role
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.create_l2cap_server(
l2cap.ClassicChannelSpec(HID_CONTROL_PSM), self.on_l2cap_connection
)
device.create_l2cap_server(
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM), self.on_l2cap_connection
)
device.on(device.EVENT_CONNECTION, self.on_device_connection)
async def connect_control_channel(self) -> None:
if not self.connection:
raise InvalidStateError("Connection is not established!")
# Create a new L2CAP connection - control channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
channel = await self.connection.create_l2cap_channel(
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
logging.exception('L2CAP connection failed.')
raise
async def connect_interrupt_channel(self) -> None:
if not self.connection:
raise InvalidStateError("Connection is not established!")
# Create a new L2CAP connection - interrupt channel
try:
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
channel = await self.connection.create_l2cap_channel(
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
logging.exception('L2CAP connection failed.')
raise
async def disconnect_interrupt_channel(self) -> None:
+173 -64
View File
@@ -22,11 +22,16 @@ import collections
import dataclasses
import logging
import struct
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast
from bumble import drivers, hci, utils
from bumble.colors import color
from bumble.core import ConnectionParameters, ConnectionPHY, PhysicalTransport
from bumble.core import (
ConnectionParameters,
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError
@@ -902,10 +907,14 @@ class Host(utils.EventEmitter):
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(self, event):
def on_command_processed(
self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event]
):
if self.pending_response:
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
if self.pending_command is None:
logger.warning('!!! pending_command is None ')
elif self.pending_command.op_code != event.command_opcode:
logger.warning(
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
@@ -919,10 +928,10 @@ class Host(utils.EventEmitter):
############################################################
# HCI handlers
############################################################
def on_hci_event(self, event):
def on_hci_event(self, event: hci.HCI_Event):
logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')
def on_hci_command_complete_event(self, event):
def on_hci_command_complete_event(self, event: hci.HCI_Command_Complete_Event):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
@@ -931,7 +940,7 @@ class Host(utils.EventEmitter):
return self.on_command_processed(event)
def on_hci_command_status_event(self, event):
def on_hci_command_status_event(self, event: hci.HCI_Command_Status_Event):
return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(
@@ -951,7 +960,7 @@ class Host(utils.EventEmitter):
)
# Classic only
def on_hci_connection_request_event(self, event):
def on_hci_connection_request_event(self, event: hci.HCI_Connection_Request_Event):
# Notify the listeners
self.emit(
'connection_request',
@@ -960,7 +969,14 @@ class Host(utils.EventEmitter):
event.link_type,
)
def on_hci_le_connection_complete_event(self, event):
def on_hci_le_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Check if this is a cancellation
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
@@ -1006,15 +1022,25 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_le_enhanced_connection_complete_event(self, event):
def on_hci_le_enhanced_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Just use the same implementation as for the non-enhanced event for now
self.on_hci_le_connection_complete_event(event)
def on_hci_le_enhanced_connection_complete_v2_event(self, event):
def on_hci_le_enhanced_connection_complete_v2_event(
self, event: hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
):
# Just use the same implementation as for the v1 event for now
self.on_hci_le_enhanced_connection_complete_event(event)
def on_hci_connection_complete_event(self, event):
def on_hci_connection_complete_event(
self, event: hci.HCI_Connection_Complete_Event
):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
@@ -1054,7 +1080,9 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_disconnection_complete_event(self, event):
def on_hci_disconnection_complete_event(
self, event: hci.HCI_Disconnection_Complete_Event
):
# Find the connection
handle = event.connection_handle
if (
@@ -1093,7 +1121,9 @@ class Host(utils.EventEmitter):
# Notify the listeners
self.emit('disconnection_failure', handle, event.status)
def on_hci_le_connection_update_complete_event(self, event):
def on_hci_le_connection_update_complete_event(
self, event: hci.HCI_LE_Connection_Update_Complete_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
return
@@ -1113,7 +1143,9 @@ class Host(utils.EventEmitter):
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event):
def on_hci_le_phy_update_complete_event(
self, event: hci.HCI_LE_PHY_Update_Complete_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
return
@@ -1143,7 +1175,9 @@ class Host(utils.EventEmitter):
):
self.on_hci_le_advertising_report_event(event)
def on_hci_le_advertising_set_terminated_event(self, event):
def on_hci_le_advertising_set_terminated_event(
self, event: hci.HCI_LE_Advertising_Set_Terminated_Event
):
self.emit(
'advertising_set_termination',
event.status,
@@ -1152,7 +1186,9 @@ class Host(utils.EventEmitter):
event.num_completed_extended_advertising_events,
)
def on_hci_le_periodic_advertising_sync_established_event(self, event):
def on_hci_le_periodic_advertising_sync_established_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_Event
):
self.emit(
'periodic_advertising_sync_establishment',
event.status,
@@ -1164,16 +1200,22 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_lost_event(self, event):
def on_hci_le_periodic_advertising_sync_lost_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event
):
self.emit('periodic_advertising_sync_loss', event.sync_handle)
def on_hci_le_periodic_advertising_report_event(self, event):
def on_hci_le_periodic_advertising_report_event(
self, event: hci.HCI_LE_Periodic_Advertising_Report_Event
):
self.emit('periodic_advertising_report', event.sync_handle, event)
def on_hci_le_biginfo_advertising_report_event(self, event):
def on_hci_le_biginfo_advertising_report_event(
self, event: hci.HCI_LE_BIGInfo_Advertising_Report_Event
):
self.emit('biginfo_advertising_report', event.sync_handle, event)
def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_request_event(self, event: hci.HCI_LE_CIS_Request_Event):
self.emit(
'cis_request',
event.acl_connection_handle,
@@ -1182,10 +1224,12 @@ class Host(utils.EventEmitter):
event.cis_id,
)
def on_hci_le_create_big_complete_event(self, event):
def on_hci_le_create_big_complete_event(
self, event: hci.HCI_LE_Create_BIG_Complete_Event
):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
logger.warning("BIS established but ISO packets not supported")
raise InvalidStateError("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
@@ -1208,8 +1252,13 @@ class Host(utils.EventEmitter):
event.iso_interval,
)
def on_hci_le_big_sync_established_event(self, event):
def on_hci_le_big_sync_established_event(
self, event: hci.HCI_LE_BIG_Sync_Established_Event
):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
raise InvalidStateError("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
@@ -1229,15 +1278,19 @@ class Host(utils.EventEmitter):
event.connection_handle,
)
def on_hci_le_big_sync_lost_event(self, event):
def on_hci_le_big_sync_lost_event(self, event: hci.HCI_LE_BIG_Sync_Lost_Event):
self.remove_big(event.big_handle)
self.emit('big_sync_lost', event.big_handle, event.reason)
def on_hci_le_terminate_big_complete_event(self, event):
def on_hci_le_terminate_big_complete_event(
self, event: hci.HCI_LE_Terminate_BIG_Complete_Event
):
self.remove_big(event.big_handle)
self.emit('big_termination', event.reason, event.big_handle)
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
def on_hci_le_periodic_advertising_sync_transfer_received_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event
):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
@@ -1250,7 +1303,9 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_V2_Event
):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
@@ -1263,11 +1318,11 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_cis_established_event(self, event):
def on_hci_le_cis_established_event(self, event: hci.HCI_LE_CIS_Established_Event):
# The remaining parameters are unused for now.
if event.status == hci.HCI_SUCCESS:
if self.iso_packet_queue is None:
logger.warning("CIS established but ISO packets not supported")
raise InvalidStateError("CIS established but ISO packets not supported")
self.cis_links[event.connection_handle] = IsoLink(
handle=event.connection_handle, packet_queue=self.iso_packet_queue
)
@@ -1294,7 +1349,9 @@ class Host(utils.EventEmitter):
'cis_establishment_failure', event.connection_handle, event.status
)
def on_hci_le_remote_connection_parameter_request_event(self, event):
def on_hci_le_remote_connection_parameter_request_event(
self, event: hci.HCI_LE_Remote_Connection_Parameter_Request_Event
):
if event.connection_handle not in self.connections:
logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
return
@@ -1313,7 +1370,9 @@ class Host(utils.EventEmitter):
)
)
def on_hci_le_long_term_key_request_event(self, event):
def on_hci_le_long_term_key_request_event(
self, event: hci.HCI_LE_Long_Term_Key_Request_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
return
@@ -1347,7 +1406,9 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_long_term_key())
def on_hci_synchronous_connection_complete_event(self, event):
def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event
):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
@@ -1373,7 +1434,9 @@ class Host(utils.EventEmitter):
# Notify the client
self.emit('sco_connection_failure', event.bd_addr, event.status)
def on_hci_synchronous_connection_changed_event(self, event):
def on_hci_synchronous_connection_changed_event(
self, event: hci.HCI_Synchronous_Connection_Changed_Event
):
pass
def on_hci_mode_change_event(self, event: hci.HCI_Mode_Change_Event):
@@ -1385,7 +1448,7 @@ class Host(utils.EventEmitter):
event.interval,
)
def on_hci_role_change_event(self, event):
def on_hci_role_change_event(self, event: hci.HCI_Role_Change_Event):
if event.status == hci.HCI_SUCCESS:
logger.debug(
f'role change for {event.bd_addr}: '
@@ -1399,7 +1462,9 @@ class Host(utils.EventEmitter):
)
self.emit('role_change_failure', event.bd_addr, event.status)
def on_hci_le_data_length_change_event(self, event):
def on_hci_le_data_length_change_event(
self, event: hci.HCI_LE_Data_Length_Change_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! DATA LENGTH CHANGE: unknown handle')
return
@@ -1413,7 +1478,9 @@ class Host(utils.EventEmitter):
event.max_rx_time,
)
def on_hci_authentication_complete_event(self, event):
def on_hci_authentication_complete_event(
self, event: hci.HCI_Authentication_Complete_Event
):
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle)
@@ -1454,7 +1521,9 @@ class Host(utils.EventEmitter):
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event):
def on_hci_encryption_key_refresh_complete_event(
self, event: hci.HCI_Encryption_Key_Refresh_Complete_Event
):
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle)
@@ -1465,7 +1534,7 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_qos_setup_complete_event(self, event):
def on_hci_qos_setup_complete_event(self, event: hci.HCI_QOS_Setup_Complete_Event):
if event.status == hci.HCI_SUCCESS:
self.emit(
'connection_qos_setup', event.connection_handle, event.service_type
@@ -1477,23 +1546,31 @@ class Host(utils.EventEmitter):
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
def on_hci_link_supervision_timeout_changed_event(
self, event: hci.HCI_Link_Supervision_Timeout_Changed_Event
):
pass
def on_hci_max_slots_change_event(self, event):
def on_hci_max_slots_change_event(self, event: hci.HCI_Max_Slots_Change_Event):
pass
def on_hci_page_scan_repetition_mode_change_event(self, event):
def on_hci_page_scan_repetition_mode_change_event(
self, event: hci.HCI_Page_Scan_Repetition_Mode_Change_Event
):
pass
def on_hci_link_key_notification_event(self, event):
def on_hci_link_key_notification_event(
self, event: hci.HCI_Link_Key_Notification_Event
):
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
def on_hci_simple_pairing_complete_event(
self, event: hci.HCI_Simple_Pairing_Complete_Event
):
logger.debug(
f'simple pairing complete for {event.bd_addr}: '
f'status={hci.HCI_Constant.status_name(event.status)}'
@@ -1503,10 +1580,10 @@ class Host(utils.EventEmitter):
else:
self.emit('classic_pairing_failure', event.bd_addr, event.status)
def on_hci_pin_code_request_event(self, event):
def on_hci_pin_code_request_event(self, event: hci.HCI_PIN_Code_Request_Event):
self.emit('pin_code_request', event.bd_addr)
def on_hci_link_key_request_event(self, event):
def on_hci_link_key_request_event(self, event: hci.HCI_Link_Key_Request_Event):
async def send_link_key():
if self.link_key_provider is None:
logger.debug('no link key provider')
@@ -1531,10 +1608,14 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_link_key())
def on_hci_io_capability_request_event(self, event):
def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event
):
self.emit('authentication_io_capability_request', event.bd_addr)
def on_hci_io_capability_response_event(self, event):
def on_hci_io_capability_response_event(
self, event: hci.HCI_IO_Capability_Response_Event
):
self.emit(
'authentication_io_capability_response',
event.bd_addr,
@@ -1542,25 +1623,33 @@ class Host(utils.EventEmitter):
event.authentication_requirements,
)
def on_hci_user_confirmation_request_event(self, event):
def on_hci_user_confirmation_request_event(
self, event: hci.HCI_User_Confirmation_Request_Event
):
self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event):
def on_hci_user_passkey_request_event(
self, event: hci.HCI_User_Passkey_Request_Event
):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
def on_hci_user_passkey_notification_event(
self, event: hci.HCI_User_Passkey_Notification_Event
):
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, _event):
def on_hci_inquiry_complete_event(self, _event: hci.HCI_Inquiry_Complete_Event):
self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event):
def on_hci_inquiry_result_with_rssi_event(
self, event: hci.HCI_Inquiry_Result_With_RSSI_Event
):
for bd_addr, class_of_device, rssi in zip(
event.bd_addr, event.class_of_device, event.rssi
):
@@ -1572,7 +1661,9 @@ class Host(utils.EventEmitter):
rssi,
)
def on_hci_extended_inquiry_result_event(self, event):
def on_hci_extended_inquiry_result_event(
self, event: hci.HCI_Extended_Inquiry_Result_Event
):
self.emit(
'inquiry_result',
event.bd_addr,
@@ -1581,7 +1672,9 @@ class Host(utils.EventEmitter):
event.rssi,
)
def on_hci_remote_name_request_complete_event(self, event):
def on_hci_remote_name_request_complete_event(
self, event: hci.HCI_Remote_Name_Request_Complete_Event
):
if event.status != hci.HCI_SUCCESS:
self.emit('remote_name_failure', event.bd_addr, event.status)
else:
@@ -1592,14 +1685,18 @@ class Host(utils.EventEmitter):
self.emit('remote_name', event.bd_addr, utf8_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
def on_hci_remote_host_supported_features_notification_event(
self, event: hci.HCI_Remote_Host_Supported_Features_Notification_Event
):
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)
def on_hci_le_read_remote_features_complete_event(self, event):
def on_hci_le_read_remote_features_complete_event(
self, event: hci.HCI_LE_Read_Remote_Features_Complete_Event
):
if event.status != hci.HCI_SUCCESS:
self.emit(
'le_remote_features_failure', event.connection_handle, event.status
@@ -1611,22 +1708,34 @@ class Host(utils.EventEmitter):
int.from_bytes(event.le_features, 'little'),
)
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(
self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event
):
self.emit('cs_remote_supported_capabilities', event)
def on_hci_le_cs_security_enable_complete_event(self, event):
def on_hci_le_cs_security_enable_complete_event(
self, event: hci.HCI_LE_CS_Security_Enable_Complete_Event
):
self.emit('cs_security', event)
def on_hci_le_cs_config_complete_event(self, event):
def on_hci_le_cs_config_complete_event(
self, event: hci.HCI_LE_CS_Config_Complete_Event
):
self.emit('cs_config', event)
def on_hci_le_cs_procedure_enable_complete_event(self, event):
def on_hci_le_cs_procedure_enable_complete_event(
self, event: hci.HCI_LE_CS_Procedure_Enable_Complete_Event
):
self.emit('cs_procedure', event)
def on_hci_le_cs_subevent_result_event(self, event):
def on_hci_le_cs_subevent_result_event(
self, event: hci.HCI_LE_CS_Subevent_Result_Event
):
self.emit('cs_subevent_result', event)
def on_hci_le_cs_subevent_result_continue_event(self, event):
def on_hci_le_cs_subevent_result_continue_event(
self, event: hci.HCI_LE_CS_Subevent_Result_Continue_Event
):
self.emit('cs_subevent_result_continue', event)
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
@@ -1639,5 +1748,5 @@ class Host(utils.EventEmitter):
event.supervision_timeout,
)
def on_hci_vendor_event(self, event):
def on_hci_vendor_event(self, event: hci.HCI_Vendor_Event):
self.emit('vendor_event', event)
-60
View File
@@ -1531,16 +1531,6 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@utils.deprecated("Please use create_classic_server")
def register_server(
self,
psm: int,
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
).psm
def create_classic_server(
self,
spec: ClassicChannelSpec,
@@ -1577,22 +1567,6 @@ class ChannelManager:
return self.servers[spec.psm]
@utils.deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
self,
psm: int,
server: Callable[[LeCreditBasedChannel], Any],
max_credits: int,
mtu: int,
mps: int,
) -> int:
return self.create_le_credit_based_server(
spec=LeCreditBasedChannelSpec(
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
),
handler=server,
).psm
def create_le_credit_based_server(
self,
spec: LeCreditBasedChannelSpec,
@@ -2145,17 +2119,6 @@ class ChannelManager:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
@utils.deprecated("Please use create_le_credit_based_channel()")
async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeCreditBasedChannel:
return await self.create_le_credit_based_channel(
connection=connection,
spec=LeCreditBasedChannelSpec(
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
),
)
async def create_le_credit_based_channel(
self,
connection: Connection,
@@ -2202,12 +2165,6 @@ class ChannelManager:
return channel
@utils.deprecated("Please use create_classic_channel()")
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
return await self.create_classic_channel(
connection=connection, spec=ClassicChannelSpec(psm=psm)
)
async def create_classic_channel(
self, connection: Connection, spec: ClassicChannelSpec
) -> ClassicChannel:
@@ -2244,20 +2201,3 @@ class ChannelManager:
raise e
return channel
# -----------------------------------------------------------------------------
# Deprecated Classes
# -----------------------------------------------------------------------------
class Channel(ClassicChannel):
@utils.deprecated("Please use ClassicChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class LeConnectionOrientedChannel(LeCreditBasedChannel):
@utils.deprecated("Please use LeCreditBasedChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
+5 -6
View File
@@ -21,7 +21,7 @@ import logging
import struct
from typing import Any, Callable, Optional, Union
from bumble import gatt, gatt_client, l2cap, utils
from bumble import data_types, gatt, gatt_client, l2cap, utils
from bumble.core import AdvertisingData
from bumble.device import Connection, Device
@@ -185,12 +185,11 @@ class AshaService(gatt.TemplateService):
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(gatt.GATT_ASHA_SERVICE)
+ bytes([self.protocol_version, self.capability])
data_types.ServiceData16BitUUID(
gatt.GATT_ASHA_SERVICE,
bytes([self.protocol_version, self.capability])
+ self.hisyncid[:4],
),
)
]
)
)
+8 -17
View File
@@ -27,7 +27,7 @@ from collections.abc import Sequence
from typing_extensions import Self
from bumble import core, gatt, hci, utils
from bumble import core, data_types, gatt, hci, utils
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
@@ -257,11 +257,10 @@ class UnicastServerAdvertisingData:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
data_types.ServiceData16BitUUID(
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE,
struct.pack(
'<2sBIB',
bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
'<BIB',
self.announcement_type,
self.available_audio_contexts,
len(self.metadata),
@@ -490,12 +489,8 @@ class BroadcastAudioAnnouncement:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
data_types.ServiceData16BitUUID(
gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self)
)
]
)
@@ -607,12 +602,8 @@ class BasicAudioAnnouncement:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
data_types.ServiceData16BitUUID(
gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self)
)
]
)
+57 -54
View File
@@ -18,7 +18,6 @@
from __future__ import annotations
import asyncio
import functools
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Union
@@ -272,7 +271,7 @@ class HearingAccessService(gatt.TemplateService):
def on_connection(connection: Connection) -> None:
@connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
self.currently_connected_clients.discard(connection)
@connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None:
@@ -373,8 +372,7 @@ class HearingAccessService(gatt.TemplateService):
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
][:num_presets]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
@@ -383,7 +381,10 @@ class HearingAccessService(gatt.TemplateService):
async def _read_preset_response(
self, connection: Connection, presets: list[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Read Presets Request operation
# aborted and shall not either continue or restart the operation when the client
# reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
@@ -404,7 +405,7 @@ class HearingAccessService(gatt.TemplateService):
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
await self._notify_preset_operations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
@@ -413,14 +414,14 @@ class HearingAccessService(gatt.TemplateService):
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
await self._notify_preset_operations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
await self._notify_preset_operations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
@@ -432,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
@@ -447,8 +448,10 @@ class HearingAccessService(gatt.TemplateService):
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Preset Changed operation aborted and
# shall continue the operation when the client reconnects.
while op_list:
try:
await connection.device.indicate_subscriber(
connection,
@@ -460,14 +463,15 @@ class HearingAccessService(gatt.TemplateService):
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
for history_list in self.preset_changed_operations_history_per_device.values():
history_list.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(self, connection: Connection, value: bytes):
del connection # Unused
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -532,48 +536,51 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(self, _: Connection, value: bytes):
async def _on_set_active_preset(self, connection: Connection, value: bytes):
del connection # Unused
await self.set_active_preset(value)
async def set_next_or_previous_preset(self, is_previous):
async def set_next_or_previous_preset(self, is_previous: bool) -> None:
'''Set the next or the previous preset as active'''
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
presets = sorted(
[
record
for record in self.preset_records.values()
if record.is_available()
],
key=lambda record: record.index,
)
current_preset = self.preset_records[self.active_preset_index]
current_preset_pos = presets.index(current_preset)
if is_previous:
new_preset = presets[(current_preset_pos - 1) % len(presets)]
else:
new_preset = presets[(current_preset_pos + 1) % len(presets)]
if not first_preset: # If no other preset are available
if current_preset == new_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
self.active_preset_index = new_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(False)
async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_previous_preset(
self, connection: Connection, value: bytes
) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(True)
async def _on_set_active_preset_synchronized_locally(
self, _: Connection, value: bytes
self, connection: Connection, value: bytes
):
del connection # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -584,8 +591,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_active_preset(value)
async def _on_set_next_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -596,8 +604,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
async def _on_set_previous_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -615,11 +624,13 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
active_preset_index_notification: asyncio.Queue
preset_control_point_indications: asyncio.Queue[bytes]
active_preset_index_notification: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid(
@@ -641,20 +652,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
async def setup_subscription(self) -> None:
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
self.preset_control_point_indications.put_nowait,
prefer_notify=False,
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
self.active_preset_index_notification.put_nowait
)
+16
View File
@@ -500,6 +500,22 @@ class OpenIntEnum(enum.IntEnum):
return obj
# -----------------------------------------------------------------------------
class CompatibleIntFlag(enum.IntFlag):
"""
Subclass of `enum.IntFlag` with a `composite_name` property that behaves like the
`name` property of the `enum.IntFlag` implementation for python vesions >= 3.11
"""
@property
def composite_name(self) -> str:
return '|'.join(
name
for flag in self.__class__
if self.value & flag.value and (name := flag.name) is not None
)
# -----------------------------------------------------------------------------
class ByteSerializable(Protocol):
"""
+7 -7
View File
@@ -21,6 +21,7 @@ import struct
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.profiles.battery_service import BatteryService
@@ -47,15 +48,14 @@ async def main() -> None:
device.advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Battery', 'utf-8'),
data_types.CompleteLocalName('Bumble Battery'),
data_types.IncompleteListOf16BitServiceUUIDs(
[battery_service.uuid]
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(battery_service.uuid),
data_types.Appearance(
data_types.Appearance.Category.WEARABLE_AUDIO_DEVICE,
data_types.Appearance.WearableAudioDeviceSubcategory.EARBUD,
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
+5 -4
View File
@@ -20,6 +20,7 @@ import struct
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.profiles.device_information_service import DeviceInformationService
@@ -53,11 +54,11 @@ async def main() -> None:
device.advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Device', 'utf-8'),
data_types.CompleteLocalName('Bumble Device'),
data_types.Appearance(
data_types.Appearance.Category.HEART_RATE_SENSOR,
data_types.Appearance.HeartRateSensorSubcategory.GENERIC_HEART_RATE_SENSOR,
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
+7 -7
View File
@@ -24,6 +24,7 @@ import sys
import time
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.profiles.device_information_service import DeviceInformationService
@@ -88,15 +89,14 @@ async def main() -> None:
device.advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Heart', 'utf-8'),
data_types.CompleteLocalName('Bumble Heart'),
data_types.IncompleteListOf16BitServiceUUIDs(
[heart_rate_service.uuid]
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(heart_rate_service.uuid),
data_types.Appearance(
data_types.Appearance.Category.HEART_RATE_SENSOR,
data_types.Appearance.HeartRateSensorSubcategory.GENERIC_HEART_RATE_SENSOR,
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
+11 -8
View File
@@ -23,6 +23,7 @@ import sys
import websockets
import bumble.logging
from bumble import data_types
from bumble.colors import color
from bumble.core import AdvertisingData
from bumble.device import Connection, Device, Peer
@@ -341,16 +342,18 @@ async def keyboard_device(device, command):
device.advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Keyboard', 'utf-8'),
data_types.CompleteLocalName('Bumble Keyboard'),
data_types.IncompleteListOf16BitServiceUUIDs(
[GATT_HUMAN_INTERFACE_DEVICE_SERVICE]
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE),
data_types.Appearance(
data_types.Appearance.Category.HUMAN_INTERFACE_DEVICE,
data_types.Appearance.HumanInterfaceDeviceSubcategory.KEYBOARD,
),
data_types.Flags(
AdvertisingData.Flags.LE_LIMITED_DISCOVERABLE_MODE
| AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)),
(AdvertisingData.FLAGS, bytes([0x05])),
]
)
)
+5 -1
View File
@@ -20,6 +20,7 @@ import struct
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import AdvertisingType, Device
from bumble.hci import Address
@@ -60,7 +61,10 @@ async def main() -> None:
device.scan_response_data = bytes(
AdvertisingData(
[
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
data_types.Appearance(
data_types.Appearance.Category.HEART_RATE_SENSOR,
data_types.Appearance.HeartRateSensorSubcategory.GENERIC_HEART_RATE_SENSOR,
)
]
)
)
+5 -9
View File
@@ -23,7 +23,7 @@ from typing import Optional
import websockets
import bumble.logging
from bumble import decoder, gatt
from bumble import data_types, decoder, gatt
from bumble.core import AdvertisingData
from bumble.device import AdvertisingParameters, Device
from bumble.profiles import asha
@@ -78,14 +78,10 @@ async def main() -> None:
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device.name, 'utf-8'),
),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(gatt.GATT_ASHA_SERVICE),
data_types.CompleteLocalName(device.name),
data_types.Flags(AdvertisingData.Flags(0x06)),
data_types.IncompleteListOf16BitServiceUUIDs(
[gatt.GATT_ASHA_SERVICE]
),
]
)
+8 -16
View File
@@ -20,6 +20,7 @@ import secrets
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.hci import Address
@@ -66,23 +67,14 @@ async def main() -> None:
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(f'Bumble LE Audio-{i}', 'utf-8'),
data_types.CompleteLocalName(f'Bumble LE Audio-{i}'),
data_types.Flags(
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(CoordinatedSetIdentificationService.UUID),
data_types.IncompleteListOf16BitServiceUUIDs(
[CoordinatedSetIdentificationService.UUID]
),
]
)
+8 -16
View File
@@ -19,6 +19,7 @@ import asyncio
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.profiles.hap import (
@@ -71,23 +72,14 @@ async def main() -> None:
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble HearingAccessService', 'utf-8'),
data_types.CompleteLocalName('Bumble HearingAccessService'),
data_types.Flags(
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(HearingAccessService.UUID),
data_types.IncompleteListOf16BitServiceUUIDs(
[HearingAccessService.UUID]
),
]
)
+5 -11
View File
@@ -23,6 +23,7 @@ from typing import Optional
import websockets
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import (
AdvertisingEventProperties,
@@ -106,17 +107,10 @@ async def main() -> None:
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
data_types.CompleteLocalName('Bumble LE Audio'),
data_types.Flags(AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG),
data_types.IncompleteListOf16BitServiceUUIDs(
[PublishedAudioCapabilitiesService.UUID]
),
]
)
+8 -16
View File
@@ -24,6 +24,7 @@ import struct
import sys
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.hci import CodecID, CodingFormat, HCI_IsoDataPacket
@@ -111,23 +112,14 @@ async def main() -> None:
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
data_types.CompleteLocalName('Bumble LE Audio'),
data_types.Flags(
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
data_types.IncompleteListOf16BitServiceUUIDs(
[PublishedAudioCapabilitiesService.UUID]
),
]
)
+8 -16
View File
@@ -24,6 +24,7 @@ from typing import Optional
import websockets
import bumble.logging
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import AdvertisingEventProperties, AdvertisingParameters, Device
from bumble.hci import CodecID, CodingFormat, OwnAddressType
@@ -127,23 +128,14 @@ async def main() -> None:
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
data_types.CompleteLocalName('Bumble LE Audio'),
data_types.Flags(
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
data_types.IncompleteListOf16BitServiceUUIDs(
[PublishedAudioCapabilitiesService.UUID]
),
]
)
+241 -48
View File
@@ -15,67 +15,261 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
from __future__ import annotations
import struct
from collections.abc import Sequence
import pytest
from bumble import avc, avctp, avrcp, controller, core, device, host, link
from bumble.transport import common
from bumble import avc, avctp, avrcp
from . import test_utils
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = link.LocalLink()
self.controllers = [
controller.Controller('C1', link=self.link, public_address=addresses[0]),
controller.Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
device.Device(
address=addresses[0],
host=host.Host(
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
),
),
device.Device(
address=addresses[1],
host=host.Host(
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
),
),
]
self.devices[0].classic_enabled = True
self.devices[1].classic_enabled = True
self.connections = [None, None]
self.protocols = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
async def setup_connections(self):
await self.devices[0].power_on()
await self.devices[1].power_on()
self.connections = await asyncio.gather(
self.devices[0].connect(
self.devices[1].public_address, core.PhysicalTransport.BR_EDR
),
self.devices[1].accept(self.devices[0].public_address),
)
class TwoDevices(test_utils.TwoDevices):
protocols: Sequence[avrcp.Protocol] = ()
async def setup_avdtp_connections(self):
self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
self.protocols[0].listen(self.devices[1])
await self.protocols[1].connect(self.connections[0])
@classmethod
async def create_with_avdtp(cls) -> TwoDevices:
devices = await cls.create_with_connection()
await devices.setup_avdtp_connections()
return devices
@pytest.mark.parametrize(
"command,",
[
avrcp.GetPlayStatusCommand(),
avrcp.GetCapabilitiesCommand(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID
),
avrcp.SetAbsoluteVolumeCommand(volume=5),
avrcp.GetElementAttributesCommand(
identifier=999,
attribute_ids=[
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.ARTIST_NAME,
],
),
avrcp.RegisterNotificationCommand(
event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123
),
avrcp.SearchCommand(
character_set_id=avrcp.CharacterSetId.UTF_8, search_string="Bumble!"
),
avrcp.PlayItemCommand(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
),
avrcp.ListPlayerApplicationSettingAttributesCommand(),
avrcp.ListPlayerApplicationSettingValuesCommand(
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
),
avrcp.GetCurrentPlayerApplicationSettingValueCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.SetPlayerApplicationSettingValueCommand(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
),
avrcp.GetPlayerApplicationSettingAttributeTextCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.GetPlayerApplicationSettingValueTextCommand(
attribute=avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
],
),
avrcp.InformDisplayableCharacterSetCommand(
character_set_id=[avrcp.CharacterSetId.UTF_8]
),
avrcp.InformBatteryStatusOfCtCommand(
battery_status=avrcp.InformBatteryStatusOfCtCommand.BatteryStatus.NORMAL
),
avrcp.SetAddressedPlayerCommand(player_id=1),
avrcp.SetBrowsedPlayerCommand(player_id=1),
avrcp.GetFolderItemsCommand(
scope=avrcp.Scope.NOW_PLAYING,
start_item=0,
end_item=1,
attributes=[avrcp.MediaAttributeId.ARTIST_NAME],
),
avrcp.ChangePathCommand(
uid_counter=1,
direction=avrcp.ChangePathCommand.Direction.DOWN,
folder_uid=2,
),
avrcp.GetItemAttributesCommand(
scope=avrcp.Scope.NOW_PLAYING,
uid=0,
uid_counter=1,
start_item=0,
end_item=0,
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
),
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
avrcp.AddToNowPlayingCommand(
scope=avrcp.Scope.NOW_PLAYING, uid=0, uid_counter=1
),
],
)
def test_command(command: avrcp.Command):
assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command
@pytest.mark.parametrize(
"event,",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(identifier=b'12356'),
avrcp.VolumeChangedEvent(volume=9),
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
avrcp.AddressedPlayerChangedEvent(
player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10)
),
avrcp.AvailablePlayersChangedEvent(),
avrcp.PlaybackPositionChangedEvent(playback_position=1314),
avrcp.NowPlayingContentChangedEvent(),
avrcp.PlayerApplicationSettingChangedEvent(
player_application_settings=[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
)
]
),
],
)
def test_event(event: avrcp.Event):
assert avrcp.Event.from_bytes(bytes(event)) == event
@pytest.mark.parametrize(
"response,",
[
avrcp.GetPlayStatusResponse(
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
),
avrcp.GetCapabilitiesResponse(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED,
capabilities=[
avrcp.EventId.ADDRESSED_PLAYER_CHANGED,
avrcp.EventId.BATT_STATUS_CHANGED,
],
),
avrcp.RegisterNotificationResponse(
event=avrcp.PlaybackPositionChangedEvent(playback_position=38)
),
avrcp.SetAbsoluteVolumeResponse(volume=99),
avrcp.GetElementAttributesResponse(
attributes=[
avrcp.MediaAttribute(
attribute_id=avrcp.MediaAttributeId.ALBUM_NAME,
attribute_value="White Album",
character_set_id=avrcp.CharacterSetId.UTF_8,
)
]
),
avrcp.ListPlayerApplicationSettingAttributesResponse(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
),
avrcp.ListPlayerApplicationSettingValuesResponse(
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
]
),
avrcp.GetCurrentPlayerApplicationSettingValueResponse(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
),
avrcp.SetPlayerApplicationSettingValueResponse(),
avrcp.GetPlayerApplicationSettingAttributeTextResponse(
attribute=[avrcp.ApplicationSetting.AttributeId.REPEAT_MODE],
character_set_id=[avrcp.CharacterSetId.UTF_8],
attribute_string=["Repeat"],
),
avrcp.GetPlayerApplicationSettingValueTextResponse(
value=[avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT],
character_set_id=[avrcp.CharacterSetId.UTF_8],
attribute_string=["All track repeat"],
),
avrcp.InformDisplayableCharacterSetResponse(),
avrcp.InformBatteryStatusOfCtResponse(),
avrcp.SetAddressedPlayerResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
avrcp.SetBrowsedPlayerResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
numbers_of_items=2,
character_set_id=avrcp.CharacterSetId.UTF_8,
folder_names=["folder1", "folder2"],
),
avrcp.GetFolderItemsResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
items=[
avrcp.MediaPlayerItem(
player_id=1,
major_player_type=avrcp.MediaPlayerItem.MajorPlayerType.AUDIO,
player_sub_type=avrcp.MediaPlayerItem.PlayerSubType.AUDIO_BOOK,
play_status=avrcp.PlayStatus.FWD_SEEK,
feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Woo",
)
],
),
avrcp.ChangePathResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED, number_of_items=2
),
avrcp.GetItemAttributesResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
attribute_value_entry_list=[
avrcp.AttributeValueEntry(
attribute_id=avrcp.MediaAttributeId.GENRE,
character_set_id=avrcp.CharacterSetId.UTF_8,
attribute_value="uuddlrlrabab",
)
],
),
avrcp.GetTotalNumberOfItemsResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
number_of_items=2,
),
avrcp.SearchResponse(
status=avrcp.StatusCode.OPERATION_COMPLETED,
uid_counter=1,
number_of_items=2,
),
avrcp.PlayItemResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
avrcp.AddToNowPlayingResponse(status=avrcp.StatusCode.OPERATION_COMPLETED),
],
)
def test_response(response: avrcp.Response):
assert avrcp.Response.from_bytes(bytes(response), response.pdu_id) == response
# -----------------------------------------------------------------------------
def test_frame_parser():
with pytest.raises(ValueError) as error:
with pytest.raises(ValueError):
avc.Frame.from_bytes(bytes.fromhex("11480000"))
x = bytes.fromhex("014D0208")
@@ -217,8 +411,7 @@ def test_passthrough_commands():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
two_devices = TwoDevices()
await two_devices.setup_connections()
two_devices = await TwoDevices.create_with_avdtp()
supported_events = await two_devices.protocols[0].get_supported_events()
assert supported_events == []
+25 -1
View File
@@ -16,7 +16,13 @@
# Imports
# -----------------------------------------------------------------------------
from bumble.core import UUID, AdvertisingData, Appearance, get_dict_key_by_value
from bumble.core import (
UUID,
AdvertisingData,
Appearance,
ClassOfDevice,
get_dict_key_by_value,
)
# -----------------------------------------------------------------------------
@@ -93,6 +99,24 @@ def test_appearance() -> None:
assert int(a) == 0x3333
# -----------------------------------------------------------------------------
def test_class_of_device() -> None:
c1 = ClassOfDevice(
ClassOfDevice.MajorServiceClasses.AUDIO
| ClassOfDevice.MajorServiceClasses.RENDERING,
ClassOfDevice.MajorDeviceClass.AUDIO_VIDEO,
ClassOfDevice.AudioVideoMinorDeviceClass.CAMCORDER,
)
assert str(c1) == "ClassOfDevice(RENDERING|AUDIO,AUDIO_VIDEO/CAMCORDER)"
c2 = ClassOfDevice(
ClassOfDevice.MajorServiceClasses.AUDIO,
ClassOfDevice.MajorDeviceClass.AUDIO_VIDEO,
0x123,
)
assert str(c2) == "ClassOfDevice(AUDIO,AUDIO_VIDEO/0x123)"
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ad_data()
+10 -15
View File
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import struct
from bumble import data_types
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.hci import HCI_Reset_Command
@@ -65,24 +66,18 @@ class HeartRateMonitor:
self.device.advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_NOT_SUPPORTED_FLAG
]
),
data_types.Flags(
AdvertisingData.Flags.LE_GENERAL_DISCOVERABLE_MODE
| AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
),
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Heart', 'utf-8'),
data_types.CompleteLocalName('Bumble Heart'),
data_types.IncompleteListOf16BitServiceUUIDs(
[self.heart_rate_service.uuid]
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(self.heart_rate_service.uuid),
data_types.Appearance(
data_types.Appearance.Category.HEART_RATE_SENSOR,
data_types.Appearance.HeartRateSensorSubcategory.GENERIC_HEART_RATE_SENSOR,
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)