Compare commits

..

3 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
b2893f26b6 fix types 2026-03-06 18:23:20 -08:00
Gilles Boccon-Gibod
90560cdea1 revert libusb-package version change 2026-03-06 17:54:03 -08:00
Gilles Boccon-Gibod
794a4a3ef0 add basic support for SCO packets over USB 2026-03-03 10:28:56 -08:00
46 changed files with 3026 additions and 3084 deletions

View File

@@ -69,7 +69,7 @@ jobs:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features --version 1.11.0 --locked # allows building/testing combinations of features
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build

View File

@@ -24,18 +24,13 @@ import dataclasses
import functools
import logging
import secrets
import sys
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
from typing import (
Any,
)
import click
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
import tomli
try:
import lc3 # type: ignore # pylint: disable=E0401
@@ -119,7 +114,7 @@ def parse_broadcast_list(filename: str) -> Sequence[Broadcast]:
broadcasts: list[Broadcast] = []
with open(filename, "rb") as config_file:
config = tomllib.load(config_file)
config = tomli.load(config_file)
for broadcast in config.get("broadcasts", []):
sources = []
for source in broadcast.get("sources", []):

View File

@@ -45,8 +45,10 @@ from bumble.hci import (
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Read_Voice_Setting_Command,
LeFeature,
SpecificationVersion,
VoiceSetting,
map_null_terminated_utf8_string,
)
from bumble.host import Host
@@ -214,6 +216,16 @@ async def get_codecs_info(host: Host) -> None:
if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Voice_Setting_Command.op_code):
response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command())
voice_setting = VoiceSetting.from_int(response3.voice_setting)
print(color('Voice Setting:', 'yellow'))
print(f' Air Coding Format: {voice_setting.air_coding_format.name}')
print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}')
print(f' Input Sample Size: {voice_setting.input_sample_size.name}')
print(f' Input Data Format: {voice_setting.input_data_format.name}')
print(f' Input Coding Format: {voice_setting.input_coding_format.name}')
# -----------------------------------------------------------------------------
async def async_main(

View File

@@ -16,6 +16,8 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import statistics
import struct
import time
import click
@@ -25,7 +27,9 @@ from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
Address,
HCI_Read_Loopback_Mode_Command,
HCI_SynchronousDataPacket,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
@@ -36,34 +40,59 @@ from bumble.transport import open_transport
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
def __init__(
self,
packet_size: int,
packet_count: int,
connection_type: str,
mode: str,
interval: int,
transport: str,
):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: int | None = None
self.connection_type = connection_type
self.connection_event = asyncio.Event()
self.mode = mode
self.interval = interval
self.done = asyncio.Event()
self.expected_cid = 0
self.expected_counter = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
self.send_timestamps: list[float] = []
self.rtts: list[float] = []
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
# The first connection handle is of type ACL,
# subsequent connections are of type SCO
if self.connection_type == "sco" and self.connection_handle is None:
self.connection_handle = connection_handle
return
self.connection_handle = connection_handle
self.connection_event.set()
def on_sco_connection(
self, address: Address, connection_handle: int, link_type: int
):
self.on_connection(connection_handle)
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
(counter,) = struct.unpack_from("H", pdu, 0)
rtt = now - self.send_timestamps[counter]
self.rtts.append(rtt)
print(f'<<< Received packet {counter}: {len(pdu)} bytes, RTT={rtt:.4f}')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
assert counter == self.expected_counter
self.expected_counter += 1
if counter == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
@@ -71,20 +100,52 @@ class Loopback:
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
if self.mode == 'throughput':
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f},',
'cyan',
)
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
if self.expected_counter == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
def on_sco_packet(self, connection_handle: int, packet) -> None:
print("---", connection_handle, packet)
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_l2cap_pdu(self.connection_handle, 0, packet)
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_hci_packet(
HCI_SynchronousDataPacket(
connection_handle=self.connection_handle,
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
data_total_length=len(packet),
data=packet,
)
)
async def send_loop(self, host: Host, sender) -> None:
for counter in range(0, self.packet_count):
print(
color(
f'>>> Sending {self.connection_type.upper()} '
f'packet {counter}: {self.packet_size} bytes',
'yellow',
)
)
self.send_timestamps.append(time.time())
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
@@ -126,8 +187,11 @@ class Loopback:
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('classic_connection', self.on_connection)
host.on('le_connection', self.on_connection)
host.on('sco_connection', self.on_sco_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
host.on('sco_packet', self.on_sco_packet)
loopback_mode = LoopbackMode.LOCAL
@@ -148,32 +212,37 @@ class Loopback:
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
if self.connection_type == "acl":
sender = self.send_acl_packet
elif self.connection_type == "sco":
sender = self.send_sco_packet
else:
raise ValueError(f'Unknown connection type: {self.connection_type}')
await self.send_loop(host, sender)
await self.done.wait()
print(color('=== Done!', 'magenta'))
bytes_sent = self.packet_size * self.packet_count
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
if self.mode == 'throughput':
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} '
f'({bytes_sent} bytes in {elapsed:.2f} seconds)',
'green',
)
)
if self.mode == 'rtt':
print(
color(
f'RTTs: min={min(self.rtts):.4f}, '
f'max={max(self.rtts):.4f}, '
f'avg={statistics.mean(self.rtts):.4f}',
'blue',
)
)
)
# -----------------------------------------------------------------------------
@@ -194,11 +263,43 @@ class Loopback:
default=10,
help='Packet count',
)
@click.option(
'--connection-type',
'-t',
metavar='TYPE',
type=click.Choice(['acl', 'sco']),
default='acl',
help='Connection type',
)
@click.option(
'--mode',
'-m',
metavar='MODE',
type=click.Choice(['throughput', 'rtt']),
default='throughput',
help='Test mode',
)
@click.option(
'--interval',
type=int,
default=100,
help='Inter-packet interval (ms) [RTT mode only]',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
def main(packet_size, packet_count, connection_type, mode, interval, transport):
bumble.logging.setup_basic_logging()
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
if connection_type == "sco" and packet_size > 255:
print("ERROR: the maximum packet size for SCO is 255")
return
async def run():
loopback = Loopback(
packet_size, packet_count, connection_type, mode, interval, transport
)
await loopback.run()
asyncio.run(run())
# -----------------------------------------------------------------------------

View File

@@ -20,12 +20,11 @@ from __future__ import annotations
import asyncio
import logging
import os
from typing import ClassVar
import click
from prompt_toolkit.shortcuts import PromptSession
from bumble import data_types, smp
from bumble import data_types
from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
@@ -41,7 +40,7 @@ from bumble.core import (
PhysicalTransport,
ProtocolError,
)
from bumble.device import Connection, Device, Peer
from bumble.device import Device, Peer
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -54,6 +53,7 @@ from bumble.hci import OwnAddressType
from bumble.keys import JsonKeyStore
from bumble.pairing import OobData, PairingConfig, PairingDelegate
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -65,7 +65,7 @@ POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance: ClassVar[Waiter | None] = None
instance: Waiter | None = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
@@ -319,13 +319,12 @@ async def on_classic_pairing(connection):
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode):
async def on_pairing_failure(connection, reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {reason.name}', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
if Waiter.instance:
Waiter.instance.terminate()
Waiter.instance.terminate()
# -----------------------------------------------------------------------------

View File

@@ -111,9 +111,14 @@ def show_device_details(device):
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
endpoint_details = (
f', Max Packet Size = {endpoint.getMaxPacketSize()}'
if endpoint_type == 'ISOCHRONOUS'
else ''
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
f'{endpoint_type} {endpoint_direction}{endpoint_details}'
)

View File

@@ -88,6 +88,13 @@ SBC_DUAL_CHANNEL_MODE = 0x01
SBC_STEREO_CHANNEL_MODE = 0x02
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
SBC_CHANNEL_MODE_NAMES = {
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
}
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
SBC_SUBBANDS = [4, 8]
@@ -95,6 +102,11 @@ SBC_SUBBANDS = [4, 8]
SBC_SNR_ALLOCATION_METHOD = 0x00
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
SBC_ALLOCATION_METHOD_NAMES = {
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
@@ -117,6 +129,13 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
@@ -248,27 +267,26 @@ class MediaCodecInformation:
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
match media_codec_type:
case CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
case CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
case CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod

View File

@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string."""
tokens: list[bytearray] = []
tokens = []
in_quotes = False
token = bytearray()
for b in buffer:
@@ -40,24 +40,23 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
tokens.append(token[1:-1])
token = bytearray()
else:
match char:
case b' ':
pass
case b',' | b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
case b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
case b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
case _:
token.extend(char)
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
@@ -72,19 +71,18 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
current: bytes | list = b''
for token in tokens:
match token:
case b',':
accumulator[-1].append(current)
current = b''
case b'(':
accumulator.append([])
case b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
case _:
current = token
if token == b',':
accumulator[-1].append(current)
current = b''
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:

View File

@@ -954,13 +954,12 @@ class Attribute(utils.EventEmitter, Generic[_T]):
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
match attribute_type:
case str():
self.type = UUID(attribute_type)
case bytes():
self.type = UUID.from_bytes(attribute_type)
case _:
self.type = attribute_type
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
self.value = value
@@ -995,31 +994,30 @@ class Attribute(utils.EventEmitter, Generic[_T]):
)
value: _T | None
match self.value:
case AttributeValue():
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
value = self.value
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
@@ -1051,27 +1049,26 @@ class Attribute(utils.EventEmitter, Generic[_T]):
decoded_value = self.decode_value(value)
match self.value:
case AttributeValue():
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
self.value = decoded_value
if isinstance(self.value, AttributeValue):
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value
self.emit(self.EVENT_WRITE, connection, decoded_value)

View File

@@ -22,14 +22,7 @@ import enum
import functools
import logging
import struct
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterable,
Mapping,
Sequence,
)
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass, field
from typing import ClassVar, SupportsBytes, TypeVar
@@ -1056,9 +1049,11 @@ class GetItemAttributesCommand(Command):
scope: Scope = field(metadata=Scope.type_metadata(1))
uid: int = field(metadata=_UINT64_BE_METADATA)
uid_counter: int = field(metadata=hci.metadata('>2'))
start_item: int = field(metadata=hci.metadata('>4'))
end_item: int = field(metadata=hci.metadata('>4'))
# When attributes is empty, all attributes will be requested.
attributes: Sequence[MediaAttributeId] = field(
metadata=MediaAttributeId.type_metadata(4, list_begin=True, list_end=True)
metadata=MediaAttributeId.type_metadata(1, list_begin=True, list_end=True)
)
@@ -1517,9 +1512,7 @@ class PlaybackPositionChangedEvent(Event):
@dataclass
class TrackChangedEvent(Event):
event_id = EventId.TRACK_CHANGED
NO_TRACK = 0xFFFFFFFFFFFFFFFF
uid: int = field(metadata=_UINT64_BE_METADATA)
identifier: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@@ -1543,19 +1536,16 @@ class PlayerApplicationSettingChangedEvent(Event):
def __post_init__(self) -> None:
super().__post_init__()
match self.attribute_id:
case ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(
self.value_id
)
case ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
case ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
case ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
case _:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
else:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
player_application_settings: Sequence[Setting] = field(
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
@@ -1629,8 +1619,6 @@ class Delegate:
supported_events: list[EventId]
supported_company_ids: list[int]
supported_player_app_settings: dict[ApplicationSetting.AttributeId, list[int]]
player_app_settings: dict[ApplicationSetting.AttributeId, int]
volume: int
playback_status: PlayStatus
@@ -1638,23 +1626,11 @@ class Delegate:
self,
supported_events: Iterable[EventId] = (),
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
supported_player_app_settings: (
Mapping[ApplicationSetting.AttributeId, Sequence[int]] | None
) = None,
) -> None:
self.supported_company_ids = list(supported_company_ids)
self.supported_events = list(supported_events)
self.volume = 0
self.playback_status = PlayStatus.STOPPED
self.supported_player_app_settings = (
{key: list(value) for key, value in supported_player_app_settings.items()}
if supported_player_app_settings
else {}
)
self.player_app_settings = {}
self.uid_counter = 0
self.addressed_player_id = 0
self.current_track_uid = TrackChangedEvent.NO_TRACK
async def get_supported_events(self) -> list[EventId]:
return self.supported_events
@@ -1687,38 +1663,6 @@ class Delegate:
async def get_playback_status(self) -> PlayStatus:
return self.playback_status
async def get_supported_player_app_settings(
self,
) -> dict[ApplicationSetting.AttributeId, list[int]]:
return self.supported_player_app_settings
async def get_current_player_app_settings(
self,
) -> dict[ApplicationSetting.AttributeId, int]:
return self.player_app_settings
async def set_player_app_settings(
self, attribute: ApplicationSetting.AttributeId, value: int
) -> None:
self.player_app_settings[attribute] = value
async def play_item(self, scope: Scope, uid: int, uid_counter: int) -> None:
logger.debug(
"@@@ play_item: scope=%s, uid=%s, uid_counter=%s",
scope,
uid,
uid_counter,
)
async def get_uid_counter(self) -> int:
return self.uid_counter
async def get_addressed_player_id(self) -> int:
return self.addressed_player_id
async def get_current_track_uid(self) -> int:
return self.current_track_uid
# TODO add other delegate methods
@@ -1966,51 +1910,6 @@ class Protocol(utils.EventEmitter):
response = self._check_response(response_context, GetElementAttributesResponse)
return list(response.attributes)
async def list_supported_player_app_settings(
self, attribute_ids: Sequence[ApplicationSetting.AttributeId] = ()
) -> dict[ApplicationSetting.AttributeId, list[int]]:
"""Get element attributes from the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
ListPlayerApplicationSettingAttributesCommand(),
)
if not attribute_ids:
list_attribute_response = self._check_response(
response_context, ListPlayerApplicationSettingAttributesResponse
)
attribute_ids = list_attribute_response.attribute
supported_settings: dict[ApplicationSetting.AttributeId, list[int]] = {}
for attribute_id in attribute_ids:
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
ListPlayerApplicationSettingValuesCommand(attribute_id),
)
list_value_response = self._check_response(
response_context, ListPlayerApplicationSettingValuesResponse
)
supported_settings[attribute_id] = list(list_value_response.value)
return supported_settings
async def get_player_app_settings(
self, attribute_ids: Sequence[ApplicationSetting.AttributeId]
) -> dict[ApplicationSetting.AttributeId, int]:
"""Get element attributes from the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
GetCurrentPlayerApplicationSettingValueCommand(attribute_ids),
)
response: GetCurrentPlayerApplicationSettingValueResponse = (
self._check_response(
response_context, GetCurrentPlayerApplicationSettingValueResponse
)
)
return {
attribute_id: value
for attribute_id, value in zip(response.attribute, response.value)
}
async def monitor_events(
self, event_id: EventId, playback_interval: int = 0
) -> AsyncIterator[Event]:
@@ -2062,13 +1961,13 @@ class Protocol(utils.EventEmitter):
async def monitor_track_changed(
self,
) -> AsyncIterator[int]:
) -> AsyncIterator[bytes]:
"""Monitor Track changes from the connected peer."""
async for event in self.monitor_events(EventId.TRACK_CHANGED, 0):
if not isinstance(event, TrackChangedEvent):
logger.warning("unexpected event class")
continue
yield event.uid
yield event.identifier
async def monitor_playback_position(
self, playback_interval: int
@@ -2161,9 +2060,11 @@ class Protocol(utils.EventEmitter):
"""Notify the connected peer of a Playback Status change."""
self.notify_event(PlaybackStatusChangedEvent(status))
def notify_track_changed(self, uid: int) -> None:
def notify_track_changed(self, identifier: bytes) -> None:
"""Notify the connected peer of a Track change."""
self.notify_event(TrackChangedEvent(uid))
if len(identifier) != 8:
raise core.InvalidArgumentError("identifier must be 8 bytes")
self.notify_event(TrackChangedEvent(identifier))
def notify_playback_position_changed(self, position: int) -> None:
"""Notify the connected peer of a Position change."""
@@ -2379,40 +2280,21 @@ class Protocol(utils.EventEmitter):
):
# TODO: catch exceptions from delegates
command = Command.from_bytes(pdu_id, pdu)
match command:
case GetCapabilitiesCommand():
self._on_get_capabilities_command(transaction_label, command)
case SetAbsoluteVolumeCommand():
self._on_set_absolute_volume_command(transaction_label, command)
case RegisterNotificationCommand():
self._on_register_notification_command(transaction_label, command)
case GetPlayStatusCommand():
self._on_get_play_status_command(transaction_label, command)
case ListPlayerApplicationSettingAttributesCommand():
self._on_list_player_application_setting_attributes_command(
transaction_label, command
)
case ListPlayerApplicationSettingValuesCommand():
self._on_list_player_application_setting_values_command(
transaction_label, command
)
case SetPlayerApplicationSettingValueCommand():
self._on_set_player_application_setting_value_command(
transaction_label, command
)
case GetCurrentPlayerApplicationSettingValueCommand():
self._on_get_current_player_application_setting_value_command(
transaction_label, command
)
case PlayItemCommand():
self._on_play_item_command(transaction_label, command)
case _:
# Not supported.
# TODO: check that this is the right way to respond in this case.
logger.debug("unsupported PDU ID")
self.send_rejected_avrcp_response(
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
)
if isinstance(command, GetCapabilitiesCommand):
self._on_get_capabilities_command(transaction_label, command)
elif isinstance(command, SetAbsoluteVolumeCommand):
self._on_set_absolute_volume_command(transaction_label, command)
elif isinstance(command, RegisterNotificationCommand):
self._on_register_notification_command(transaction_label, command)
elif isinstance(command, GetPlayStatusCommand):
self._on_get_play_status_command(transaction_label, command)
else:
# Not supported.
# TODO: check that this is the right way to respond in this case.
logger.debug("unsupported PDU ID")
self.send_rejected_avrcp_response(
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
)
else:
logger.debug("unsupported command type")
self.send_rejected_avrcp_response(
@@ -2440,29 +2322,26 @@ class Protocol(utils.EventEmitter):
# is Ok, but if/when more responses are supported, a lookup mechanism would be
# more appropriate.
response: Response | None = None
match response_code:
case avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(
pdu_id=pdu_id, status_code=StatusCode(pdu[0])
)
case avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
case (
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE
| avc.ResponseFrame.ResponseCode.INTERIM
| avc.ResponseFrame.ResponseCode.CHANGED
| avc.ResponseFrame.ResponseCode.ACCEPTED
):
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
case _:
logger.debug("unexpected response code")
pending_command.response.set_exception(
core.ProtocolError(
error_code=None,
error_namespace="avrcp",
details="unexpected response code",
)
if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(pdu_id=pdu_id, status_code=StatusCode(pdu[0]))
elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
elif response_code in (
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
avc.ResponseFrame.ResponseCode.INTERIM,
avc.ResponseFrame.ResponseCode.CHANGED,
avc.ResponseFrame.ResponseCode.ACCEPTED,
):
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
else:
logger.debug("unexpected response code")
pending_command.response.set_exception(
core.ProtocolError(
error_code=None,
error_namespace="avrcp",
details="unexpected response code",
)
)
if response is None:
self.recycle_pending_command(pending_command)
@@ -2633,18 +2512,22 @@ class Protocol(utils.EventEmitter):
async def get_supported_events() -> None:
capabilities: Sequence[bytes | SupportsBytes]
match command.capability_id:
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
capabilities = await self.delegate.get_supported_events()
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED.COMPANY_ID:
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
case _:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
if (
command.capability_id
== GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
):
capabilities = await self.delegate.get_supported_events()
elif (
command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID
):
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
else:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
@@ -2689,121 +2572,6 @@ class Protocol(utils.EventEmitter):
self._delegate_command(transaction_label, command, get_playback_status())
def _on_list_player_application_setting_attributes_command(
self,
transaction_label: int,
command: ListPlayerApplicationSettingAttributesCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
supported_settings = await self.delegate.get_supported_player_app_settings()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
ListPlayerApplicationSettingAttributesResponse(
list(supported_settings.keys())
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_list_player_application_setting_values_command(
self,
transaction_label: int,
command: ListPlayerApplicationSettingValuesCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
supported_settings = await self.delegate.get_supported_player_app_settings()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
ListPlayerApplicationSettingValuesResponse(
supported_settings.get(command.attribute, [])
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_get_current_player_application_setting_value_command(
self,
transaction_label: int,
command: GetCurrentPlayerApplicationSettingValueCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
current_settings = await self.delegate.get_current_player_app_settings()
if not all(
attribute in current_settings for attribute in command.attribute
):
self.send_not_implemented_avrcp_response(
transaction_label,
PduId.GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE,
)
return
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetCurrentPlayerApplicationSettingValueResponse(
attribute=command.attribute,
value=[
current_settings[attribute] for attribute in command.attribute
],
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_set_player_application_setting_value_command(
self,
transaction_label: int,
command: SetPlayerApplicationSettingValueCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def set_player_app_settings() -> None:
for attribute, value in zip(command.attribute, command.value):
await self.delegate.set_player_app_settings(attribute, value)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
SetPlayerApplicationSettingValueResponse(),
)
self._delegate_command(transaction_label, command, set_player_app_settings())
def _on_play_item_command(
self,
transaction_label: int,
command: PlayItemCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def play_item() -> None:
await self.delegate.play_item(
scope=command.scope, uid=command.uid, uid_counter=command.uid_counter
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
PlayItemResponse(status=StatusCode.OPERATION_COMPLETED),
)
self._delegate_command(transaction_label, command, play_item())
def _on_register_notification_command(
self, transaction_label: int, command: RegisterNotificationCommand
) -> None:
@@ -2819,51 +2587,26 @@ class Protocol(utils.EventEmitter):
)
return
event: Event
match command.event_id:
case EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
event = VolumeChangedEvent(volume)
case EventId.PLAYBACK_STATUS_CHANGED:
playback_status = await self.delegate.get_playback_status()
event = PlaybackStatusChangedEvent(play_status=playback_status)
case EventId.NOW_PLAYING_CONTENT_CHANGED:
event = NowPlayingContentChangedEvent()
case EventId.PLAYER_APPLICATION_SETTING_CHANGED:
settings = await self.delegate.get_current_player_app_settings()
event = PlayerApplicationSettingChangedEvent(
[
PlayerApplicationSettingChangedEvent.Setting(
attribute, value # type: ignore
)
for attribute, value in settings.items()
]
)
case EventId.AVAILABLE_PLAYERS_CHANGED:
event = AvailablePlayersChangedEvent()
case EventId.ADDRESSED_PLAYER_CHANGED:
event = AddressedPlayerChangedEvent(
AddressedPlayerChangedEvent.Player(
player_id=await self.delegate.get_addressed_player_id(),
uid_counter=await self.delegate.get_uid_counter(),
)
)
case EventId.UIDS_CHANGED:
event = UidsChangedEvent(await self.delegate.get_uid_counter())
case EventId.TRACK_CHANGED:
event = TrackChangedEvent(
await self.delegate.get_current_track_uid()
)
case _:
logger.warning(
"Event supported but not handled %s", command.event_id
)
return
response: Response
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
elif command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=playback_status)
)
elif command.event_id == EventId.NOW_PLAYING_CONTENT_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(NowPlayingContentChangedEvent())
else:
logger.warning("Event supported but not handled %s", command.event_id)
return
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
RegisterNotificationResponse(event),
response,
)
self._register_notification_listener(transaction_label, command)

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,6 @@ from __future__ import annotations
import dataclasses
import enum
import functools
import struct
from collections.abc import Iterable
from typing import (
@@ -274,18 +273,6 @@ class UUID:
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
@functools.cached_property
def uuid_128_bytes(self) -> bytes:
match len(self.uuid_bytes):
case 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
case 4:
return self.BASE_UUID + self.uuid_bytes
case 16:
return self.uuid_bytes
case _:
assert False, "unreachable"
def to_bytes(self, force_128: bool = False) -> bytes:
'''
Serialize UUID in little-endian byte-order
@@ -293,7 +280,14 @@ class UUID:
if not force_128:
return self.uuid_bytes
return self.uuid_128_bytes
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
else:
assert False, "unreachable"
def to_pdu_bytes(self) -> bytes:
'''
@@ -323,7 +317,7 @@ class UUID:
def __eq__(self, other: object) -> bool:
if isinstance(other, UUID):
return self.uuid_128_bytes == other.uuid_128_bytes
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
if isinstance(other, str):
return UUID(other) == self
@@ -331,7 +325,7 @@ class UUID:
return False
def __hash__(self) -> int:
return hash(self.uuid_128_bytes)
return hash(self.uuid_bytes)
def __str__(self) -> str:
result = self.to_hex_str(separator='-')
@@ -1775,71 +1769,66 @@ class AdvertisingData:
@classmethod
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
match ad_type:
case AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
case AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
if ad_type == AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
case AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
case AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
)
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
case AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(
struct.unpack_from('<H', ad_data, 0)[0]
)
ad_data_str = str(appearance)
case AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
case _:
ad_type_str = AdvertisingData.Type(ad_type).name
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
ad_data_str = str(appearance)
elif ad_type == AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
else:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}'
@@ -2116,10 +2105,13 @@ class AdvertisingData:
# -----------------------------------------------------------------------------
# Connection PHY
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ConnectionPHY:
tx_phy: int
rx_phy: int
def __init__(self, tx_phy, rx_phy):
self.tx_phy = tx_phy
self.rx_phy = rx_phy
def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
# -----------------------------------------------------------------------------

View File

@@ -1423,6 +1423,9 @@ class ScoLink(utils.CompositeEventEmitter):
acl_connection: Connection
handle: int
link_type: int
rx_packet_length: int
tx_packet_length: int
air_mode: hci.CodecID
sink: Callable[[hci.HCI_SynchronousDataPacket], Any] | None = None
EVENT_DISCONNECTION: ClassVar[str] = "disconnection"
@@ -1837,7 +1840,6 @@ class Connection(utils.CompositeEventEmitter):
self.pairing_peer_io_capability = None
self.pairing_peer_authentication_requirements = None
self.peer_le_features = hci.LeFeatureMask(0)
self.peer_classic_features = hci.LmpFeatureMask(0)
self.cs_configs = {}
self.cs_procedures = {}
@@ -2055,15 +2057,6 @@ class Connection(utils.CompositeEventEmitter):
self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features
async def get_remote_classic_features(self) -> hci.LmpFeatureMask:
"""[Classic Only] Reads remote LMP supported features.
Returns:
LMP features supported by the remote device.
"""
self.peer_classic_features = await self.device.get_remote_classic_features(self)
return self.peer_classic_features
def on_att_mtu_update(self, mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
@@ -2159,7 +2152,6 @@ class DeviceConfiguration:
)
eatt_enabled: bool = False
gatt_services: list[dict[str, Any]] = field(init=False)
smp_debug_mode: bool = False
def __post_init__(self) -> None:
self.gatt_services = []
@@ -2572,7 +2564,6 @@ class Device(utils.CompositeEventEmitter):
),
),
)
self.smp_manager.debug_mode = self.config.smp_debug_mode
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -5293,77 +5284,6 @@ class Device(utils.CompositeEventEmitter):
)
return await read_feature_future
async def get_remote_classic_features(
self, connection: Connection
) -> hci.LmpFeatureMask:
"""[Classic Only] Reads remote LE supported features.
Args:
handle: connection handle to read LMP features.
Returns:
LMP features supported by the remote device.
"""
with closing(utils.EventWatcher()) as watcher:
read_feature_future: asyncio.Future[tuple[int, int]] = (
asyncio.get_running_loop().create_future()
)
read_features = hci.LmpFeatureMask(0)
current_page_number = 0
@watcher.on(self.host, 'classic_remote_features')
def on_classic_remote_features(
handle: int,
status: int,
features: int,
page_number: int,
max_page_number: int,
) -> None:
if handle != connection.handle:
logger.warning(
"Received classic_remote_features for wrong handle, expected=0x%04X, got=0x%04X",
connection.handle,
handle,
)
return
if page_number != current_page_number:
logger.warning(
"Received classic_remote_features for wrong page, expected=%d, got=%d",
current_page_number,
page_number,
)
return
if status == hci.HCI_ErrorCode.SUCCESS:
read_feature_future.set_result((features, max_page_number))
else:
read_feature_future.set_exception(hci.HCI_Error(status))
await self.send_async_command(
hci.HCI_Read_Remote_Supported_Features_Command(
connection_handle=connection.handle
)
)
new_features, max_page_number = await read_feature_future
read_features |= new_features
if not (read_features & hci.LmpFeatureMask.EXTENDED_FEATURES):
return read_features
while current_page_number <= max_page_number:
read_feature_future = asyncio.get_running_loop().create_future()
await self.send_async_command(
hci.HCI_Read_Remote_Extended_Features_Command(
connection_handle=connection.handle,
page_number=current_page_number,
)
)
new_features, max_page_number = await read_feature_future
read_features |= new_features << (current_page_number * 64)
current_page_number += 1
return read_features
@utils.experimental('Only for testing.')
async def get_remote_cs_capabilities(
self, connection: Connection
@@ -6051,7 +5971,7 @@ class Device(utils.CompositeEventEmitter):
def on_connection_request(
self, bd_addr: hci.Address, class_of_device: int, link_type: int
):
logger.debug(f'*** Connection request: {bd_addr}')
logger.debug(f'*** Connection request: {bd_addr} link_type={link_type}')
# Handle SCO request.
if link_type in (
@@ -6061,6 +5981,7 @@ class Device(utils.CompositeEventEmitter):
if connection := self.find_connection_by_bd_addr(
bd_addr, transport=PhysicalTransport.BR_EDR
):
connection.emit(self.EVENT_SCO_REQUEST, link_type)
self.emit(self.EVENT_SCO_REQUEST, connection, link_type)
else:
logger.error(f'SCO request from a non-connected device {bd_addr}')
@@ -6420,8 +6341,7 @@ class Device(utils.CompositeEventEmitter):
logger.warning('peer name is not valid UTF-8')
if connection:
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
else:
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
# [Classic only]
@host_event_handler
@@ -6438,7 +6358,13 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_address
@utils.experimental('Only for testing.')
def on_sco_connection(
self, acl_connection: Connection, sco_handle: int, link_type: int
self,
acl_connection: Connection,
sco_handle: int,
link_type: int,
rx_packet_length: int,
tx_packet_length: int,
air_mode: int,
) -> None:
logger.debug(
f'*** SCO connected: {acl_connection.peer_address}, '
@@ -6450,7 +6376,11 @@ class Device(utils.CompositeEventEmitter):
acl_connection=acl_connection,
handle=sco_handle,
link_type=link_type,
rx_packet_length=rx_packet_length,
tx_packet_length=tx_packet_length,
air_mode=hci.CodecID(air_mode),
)
acl_connection.emit(self.EVENT_SCO_CONNECTION, sco_link)
self.emit(self.EVENT_SCO_CONNECTION, sco_link)
# [Classic only]
@@ -6461,7 +6391,8 @@ class Device(utils.CompositeEventEmitter):
self, acl_connection: Connection, status: int
) -> None:
logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***')
self.emit(self.EVENT_SCO_CONNECTION_FAILURE)
acl_connection.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
self.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
# [Classic only]
@host_event_handler
@@ -6924,15 +6855,18 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_address
def on_classic_pairing(self, connection: Connection) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING)
self.emit(connection.EVENT_CLASSIC_PAIRING, connection)
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
self.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, connection, status)
def on_pairing_start(self, connection: Connection) -> None:
connection.emit(connection.EVENT_PAIRING_START)
self.emit(connection.EVENT_PAIRING_START, connection)
def on_pairing(
self,

View File

@@ -201,51 +201,50 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
value = data[2 : 2 + value_length]
typed_value: Any
match value_type:
case ValueType.END:
break
if value_type == ValueType.END:
break
case ValueType.CNVI | ValueType.CNVR:
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
case ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
case (
ValueType.USB_VENDOR_ID
| ValueType.USB_PRODUCT_ID
| ValueType.DEVICE_REVISION
):
(typed_value,) = struct.unpack("<H", value)
case ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
case (
ValueType.BUILD_TYPE
| ValueType.BUILD_NUMBER
| ValueType.SECURE_BOOT
| ValueType.OTP_LOCK
| ValueType.API_LOCK
| ValueType.DEBUG_LOCK
| ValueType.SECURE_BOOT_ENGINE_TYPE
):
typed_value = value[0]
case ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
case ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
case ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
case _:
typed_value = value
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]

60
bumble/gap.py Normal file
View File

@@ -0,0 +1,60 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from bumble.gatt import (
GATT_APPEARANCE_CHARACTERISTIC,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
Characteristic,
Service,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)

View File

@@ -31,7 +31,6 @@ from typing import (
ClassVar,
Generic,
Literal,
SupportsBytes,
TypeVar,
cast,
)
@@ -248,6 +247,28 @@ HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
HCI_VERSION_BLUETOOTH_CORE_1_1: 'HCI_VERSION_BLUETOOTH_CORE_1_1',
HCI_VERSION_BLUETOOTH_CORE_1_2: 'HCI_VERSION_BLUETOOTH_CORE_1_2',
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_0_EDR',
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_1_EDR',
HCI_VERSION_BLUETOOTH_CORE_3_0_HS: 'HCI_VERSION_BLUETOOTH_CORE_3_0_HS',
HCI_VERSION_BLUETOOTH_CORE_4_0: 'HCI_VERSION_BLUETOOTH_CORE_4_0',
HCI_VERSION_BLUETOOTH_CORE_4_1: 'HCI_VERSION_BLUETOOTH_CORE_4_1',
HCI_VERSION_BLUETOOTH_CORE_4_2: 'HCI_VERSION_BLUETOOTH_CORE_4_2',
HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0',
HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1',
HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2',
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
}
LMP_VERSION_NAMES = HCI_VERSION_NAMES
# HCI Packet types
HCI_COMMAND_PACKET = 0x01
HCI_ACL_DATA_PACKET = 0x02
@@ -366,8 +387,8 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2A
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2B
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
@@ -1748,6 +1769,61 @@ class CodingFormat:
)
@dataclasses.dataclass(frozen=True)
class VoiceSetting:
class AirCodingFormat(enum.IntEnum):
CVSD = 0
U_LAW = 1
A_LAW = 2
TRANSPARENT_DATA = 3
class InputSampleSize(enum.IntEnum):
SIZE_8_BITS = 0
SIZE_16_BITS = 1
class InputDataFormat(enum.IntEnum):
ONES_COMPLEMENT = 0
TWOS_COMPLEMENT = 1
SIGN_AND_MAGNITUDE = 2
UNSIGNED = 3
class InputCodingFormat(enum.IntEnum):
LINEAR = 0
U_LAW = 1
A_LAW = 2
RESERVED = 3
air_coding_format: AirCodingFormat = AirCodingFormat.CVSD
linear_pcm_bit_position: int = 0
input_sample_size: InputSampleSize = InputSampleSize.SIZE_8_BITS
input_data_format: InputDataFormat = InputDataFormat.ONES_COMPLEMENT
input_coding_format: InputCodingFormat = InputCodingFormat.LINEAR
@classmethod
def from_int(cls, value: int) -> VoiceSetting:
air_coding_format = cls.AirCodingFormat(value & 0b11)
linear_pcm_bit_position = (value >> 2) & 0b111
input_sample_size = cls.InputSampleSize((value >> 5) & 0b1)
input_data_format = cls.InputDataFormat((value >> 6) & 0b11)
input_coding_format = cls.InputCodingFormat((value >> 8) & 0b11)
return cls(
air_coding_format=air_coding_format,
linear_pcm_bit_position=linear_pcm_bit_position,
input_sample_size=input_sample_size,
input_data_format=input_data_format,
input_coding_format=input_coding_format,
)
def __int__(self) -> int:
return (
self.air_coding_format
| (self.linear_pcm_bit_position << 2)
| (self.input_sample_size << 5)
| (self.input_data_format << 6)
| (self.input_coding_format << 8)
)
# -----------------------------------------------------------------------------
class HCI_Constant:
@staticmethod
@@ -1839,46 +1915,44 @@ class HCI_Object:
field_type = field_type['parser']
# Parse the field
match field_type:
case '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
case 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_length = data[offset]
offset += 1
field_value = data[offset : offset + field_length]
return (field_value, field_length + 1)
case 1:
# 8-bit unsigned
return (data[offset], 1)
case -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
case 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
case '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
case -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
case 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
case 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
case '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
case int() if 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_type == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_length = data[offset]
offset += 1
field_value = data[offset : offset + field_length]
return (field_value, field_length + 1)
if field_type == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_type == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_type == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_type == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_type == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_type == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_type == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type):
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
@@ -1935,58 +2009,60 @@ class HCI_Object:
# Serialize the field
if serializer:
return serializer(field_value)
match field_type:
case 1:
# 8-bit unsigned
return bytes([field_value])
case -1:
# 8-bit signed
return struct.pack('b', field_value)
case 2:
# 16-bit unsigned
return struct.pack('<H', field_value)
case '>2':
# 16-bit unsigned big-endian
return struct.pack('>H', field_value)
case -2:
# 16-bit signed
return struct.pack('<h', field_value)
case 3:
# 24-bit unsigned
return struct.pack('<I', field_value)[0:3]
case 4:
# 32-bit unsigned
return struct.pack('<I', field_value)
case '>4':
# 32-bit unsigned big-endian
return struct.pack('>I', field_value)
case '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
return bytes([field_value])
else:
raise InvalidArgumentError('value too large for *-typed field')
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
return bytes(field_value)
case 'v':
# Variable-length bytes field, with 1-byte length at the beginning
raise InvalidArgumentError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
field_length = len(field_bytes)
return bytes([field_length]) + field_bytes
if isinstance(field_value, (bytes, bytearray, SupportsBytes)):
elif field_type == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_bytes = bytes(field_value)
field_length = len(field_bytes)
field_bytes = bytes([field_length]) + field_bytes
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, '__bytes__'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
return field_bytes + bytes(field_type - len(field_bytes))
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
return field_bytes[:field_type]
return field_bytes
field_bytes = field_bytes[:field_type]
else:
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes
@staticmethod
def dict_to_bytes(hci_object, object_fields):
@@ -2886,6 +2962,23 @@ class HCI_Read_Clock_Offset_Command(HCI_AsyncCommand):
connection_handle: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_Accept_Synchronous_Connection_Request_Command(HCI_AsyncCommand):
'''
See Bluetooth spec @ 7.1.27 Accept Synchronous Connection Request Command
'''
bd_addr: Address = field(metadata=metadata(Address.parse_address))
transmit_bandwidth: int = field(metadata=metadata(4))
receive_bandwidth: int = field(metadata=metadata(4))
max_latency: int = field(metadata=metadata(2))
voice_setting: int = field(metadata=metadata(2))
retransmission_effort: int = field(metadata=metadata(1))
packet_type: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -3944,6 +4037,23 @@ class HCI_Read_Local_OOB_Extended_Data_Command(
'''
# -----------------------------------------------------------------------------
@HCI_SyncCommand.sync_command(HCI_StatusReturnParameters)
@dataclasses.dataclass
class HCI_Configure_Data_Path_Command(HCI_SyncCommand[HCI_StatusReturnParameters]):
'''
See Bluetooth spec @ 7.3.101 Configure Data Path Command
'''
class DataPathDirection(SpecableEnum):
INPUT = 0x00
OUTPUT = 0x01
data_path_direction: DataPathDirection = field(metadata=metadata(1))
data_path_id: int = field(metadata=metadata(1))
vendor_specific_config: bytes = field(metadata=metadata('*'))
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class HCI_Read_Local_Version_Information_ReturnParameters(HCI_StatusReturnParameters):
@@ -4715,7 +4825,7 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_SyncCommand[HCI_StatusReturnParame
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class HCI_LE_Read_Resolving_List_Size_ReturnParameters(HCI_StatusReturnParameters):
resolving_list_size: int = field(metadata=metadata(1))
resolving_list_size: bytes = field(metadata=metadata(1))
@HCI_SyncCommand.sync_command(HCI_LE_Read_Resolving_List_Size_ReturnParameters)
@@ -7334,7 +7444,7 @@ class HCI_Connection_Complete_Event(HCI_Event):
status: int = field(metadata=metadata(STATUS_SPEC))
connection_handle: int = field(metadata=metadata(2))
bd_addr: Address = field(metadata=metadata(Address.parse_address))
link_type: int = field(metadata=LinkType.type_metadata(1))
link_type: LinkType = field(metadata=LinkType.type_metadata(1))
encryption_enabled: int = field(metadata=metadata(1))
@@ -7730,12 +7840,6 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
SCO = 0x00
ESCO = 0x02
class AirMode(SpecableEnum):
U_LAW_LOG = 0x00
A_LAW_LOG_AIR_MORE = 0x01
CVSD = 0x02
TRANSPARENT_DATA = 0x03
status: int = field(metadata=metadata(STATUS_SPEC))
connection_handle: int = field(metadata=metadata(2))
bd_addr: Address = field(metadata=metadata(Address.parse_address))
@@ -7744,7 +7848,7 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
retransmission_window: int = field(metadata=metadata(1))
rx_packet_length: int = field(metadata=metadata(2))
tx_packet_length: int = field(metadata=metadata(2))
air_mode: int = field(metadata=AirMode.type_metadata(1))
air_mode: int = field(metadata=CodecID.type_metadata(1))
# -----------------------------------------------------------------------------
@@ -7976,7 +8080,9 @@ class HCI_AclDataPacket(HCI_Packet):
bc_flag = (h >> 14) & 3
data = packet[5:]
if len(data) != data_total_length:
raise InvalidPacketError('invalid packet length')
raise InvalidPacketError(
f'invalid packet length {len(data)} != {data_total_length}'
)
return cls(
connection_handle=connection_handle,
pb_flag=pb_flag,
@@ -8009,10 +8115,16 @@ class HCI_SynchronousDataPacket(HCI_Packet):
See Bluetooth spec @ 5.4.3 HCI SCO Data Packets
'''
class Status(enum.IntEnum):
CORRECTLY_RECEIVED_DATA = 0b00
POSSIBLY_INVALID_DATA = 0b01
NO_DATA = 0b10
DATA_PARTIALLY_LOST = 0b11
hci_packet_type = HCI_SYNCHRONOUS_DATA_PACKET
connection_handle: int
packet_status: int
packet_status: Status
data_total_length: int
data: bytes
@@ -8021,7 +8133,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
# Read the header
h, data_total_length = struct.unpack_from('<HB', packet, 1)
connection_handle = h & 0xFFF
packet_status = (h >> 12) & 0b11
packet_status = cls.Status((h >> 12) & 0b11)
data = packet[4:]
if len(data) != data_total_length:
raise InvalidPacketError(
@@ -8045,7 +8157,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
return (
f'{color("SCO", "blue")}: '
f'handle=0x{self.connection_handle:04x}, '
f'ps={self.packet_status}, '
f'ps={self.packet_status.name}, '
f'data_total_length={self.data_total_length}, '
f'data={self.data.hex()}'
)
@@ -8073,8 +8185,8 @@ class HCI_IsoDataPacket(HCI_Packet):
def __post_init__(self) -> None:
self.ts_flag = self.time_stamp is not None
@staticmethod
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
@classmethod
def from_bytes(cls, packet: bytes) -> HCI_IsoDataPacket:
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
@@ -8103,7 +8215,7 @@ class HCI_IsoDataPacket(HCI_Packet):
pos += 4
iso_sdu_fragment = packet[pos:]
return HCI_IsoDataPacket(
return cls(
connection_handle=connection_handle,
pb_flag=pb_flag,
ts_flag=ts_flag,

View File

@@ -26,7 +26,7 @@ import logging
import re
import traceback
from collections.abc import Iterable
from typing import Any, ClassVar, Literal, overload
from typing import TYPE_CHECKING, Any, ClassVar
from typing_extensions import Self
@@ -166,7 +166,7 @@ class AgFeature(enum.IntFlag):
VOICE_RECOGNITION_TEXT = 0x2000
class AudioCodec(enum.IntEnum):
class AudioCodec(utils.OpenIntEnum):
"""
Audio Codec IDs (normative).
@@ -178,7 +178,7 @@ class AudioCodec(enum.IntEnum):
LC3_SWB = 0x03 # Support for LC3-SWB audio codec
class HfIndicator(enum.IntEnum):
class HfIndicator(utils.OpenIntEnum):
"""
HF Indicators (normative).
@@ -207,7 +207,7 @@ class CallHoldOperation(enum.Enum):
)
class ResponseHoldStatus(enum.IntEnum):
class ResponseHoldStatus(utils.OpenIntEnum):
"""
Response Hold status (normative).
@@ -235,7 +235,7 @@ class AgIndicator(enum.Enum):
BATTERY_CHARGE = 'battchg'
class CallSetupAgIndicator(enum.IntEnum):
class CallSetupAgIndicator(utils.OpenIntEnum):
"""
Values for the Call Setup AG indicator (normative).
@@ -248,7 +248,7 @@ class CallSetupAgIndicator(enum.IntEnum):
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
class CallHeldAgIndicator(enum.IntEnum):
class CallHeldAgIndicator(utils.OpenIntEnum):
"""
Values for the Call Held AG indicator (normative).
@@ -262,7 +262,7 @@ class CallHeldAgIndicator(enum.IntEnum):
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
class CallInfoDirection(enum.IntEnum):
class CallInfoDirection(utils.OpenIntEnum):
"""
Call Info direction (normative).
@@ -273,7 +273,7 @@ class CallInfoDirection(enum.IntEnum):
MOBILE_TERMINATED_CALL = 1
class CallInfoStatus(enum.IntEnum):
class CallInfoStatus(utils.OpenIntEnum):
"""
Call Info status (normative).
@@ -288,7 +288,7 @@ class CallInfoStatus(enum.IntEnum):
WAITING = 5
class CallInfoMode(enum.IntEnum):
class CallInfoMode(utils.OpenIntEnum):
"""
Call Info mode (normative).
@@ -301,7 +301,7 @@ class CallInfoMode(enum.IntEnum):
UNKNOWN = 9
class CallInfoMultiParty(enum.IntEnum):
class CallInfoMultiParty(utils.OpenIntEnum):
"""
Call Info Multi-Party state (normative).
@@ -388,7 +388,7 @@ class CallLineIdentification:
)
class VoiceRecognitionState(enum.IntEnum):
class VoiceRecognitionState(utils.OpenIntEnum):
"""
vrec values provided in AT+BVRA command.
@@ -401,7 +401,7 @@ class VoiceRecognitionState(enum.IntEnum):
ENHANCED_READY = 2
class CmeError(enum.IntEnum):
class CmeError(utils.OpenIntEnum):
"""
CME ERROR codes (partial listed).
@@ -420,6 +420,61 @@ class CmeError(enum.IntEnum):
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = {
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
}
# Unsolicited responses and statuses.
UNSOLICITED_CODES = {
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
}
# Status codes
STATUS_CODES = {
"+CME ERROR",
@@ -672,9 +727,12 @@ class HfProtocol(utils.EventEmitter):
dlc: rfcomm.DLC
command_lock: asyncio.Lock
pending_command: str | None = None
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray
active_codec: AudioCodec
@@ -747,39 +805,16 @@ class HfProtocol(utils.EventEmitter):
self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue.
if self.pending_command and (
response.code in STATUS_CODES or response.code in self.pending_command
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
):
self.response_queue.put_nowait(response)
else:
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
) -> None: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.SINGLE],
) -> AtResponse: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.MULTIPLE],
) -> list[AtResponse]: ...
else:
logger.warning(
f"dropping unexpected response with code '{response.code}'"
)
async def execute_command(
self,
@@ -800,34 +835,27 @@ class HfProtocol(utils.EventEmitter):
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
ProtocolError: the status is not OK.
"""
try:
async with self.command_lock:
self.pending_command = cmd
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if (
response_type == AtResponseType.SINGLE
and len(responses) != 1
):
raise HfpProtocolError("NO ANSWER")
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
finally:
self.pending_command = None
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
async def initiate_slc(self):
"""4.2.1 Service Level Connection Initialization."""
@@ -1039,6 +1067,7 @@ class HfProtocol(utils.EventEmitter):
responses = await self.execute_command(
"AT+CLCC", response_type=AtResponseType.MULTIPLE
)
assert isinstance(responses, list)
calls = []
for response in responses:
@@ -1595,7 +1624,7 @@ class AgProtocol(utils.EventEmitter):
# -----------------------------------------------------------------------------
class ProfileVersion(enum.IntEnum):
class ProfileVersion(utils.OpenIntEnum):
"""
Profile version (normative).
@@ -2047,6 +2076,7 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
max_latency=0x0008,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
@@ -2062,7 +2092,6 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
max_latency=0x000D,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5

View File

@@ -22,7 +22,7 @@ import collections
import dataclasses
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, TypeVar, overload
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from bumble import drivers, hci, utils
from bumble.colors import color
@@ -686,14 +686,18 @@ class Host(utils.EventEmitter):
self.pending_response, timeout=response_timeout
)
return response
except asyncio.TimeoutError:
raise
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
if response is None or (
response.num_hci_command_packets and self.command_semaphore.locked()
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
):
self.command_semaphore.release()
@@ -864,7 +868,7 @@ class Host(utils.EventEmitter):
self.send_hci_packet(
hci.HCI_SynchronousDataPacket(
connection_handle=connection_handle,
packet_status=0,
packet_status=hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
data_total_length=len(sdu),
data=sdu,
)
@@ -1000,19 +1004,18 @@ class Host(utils.EventEmitter):
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
match packet:
case hci.HCI_Command():
self.on_hci_command_packet(packet)
case hci.HCI_Event():
self.on_hci_event_packet(packet)
case hci.HCI_AclDataPacket():
self.on_hci_acl_data_packet(packet)
case hci.HCI_SynchronousDataPacket():
self.on_hci_sco_data_packet(packet)
case hci.HCI_IsoDataPacket():
self.on_hci_iso_data_packet(packet)
case _:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
self.on_hci_command_packet(cast(hci.HCI_Command, packet))
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
self.on_hci_event_packet(cast(hci.HCI_Event, packet))
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}')
@@ -1176,11 +1179,28 @@ class Host(utils.EventEmitter):
def on_hci_connection_complete_event(
self, event: hci.HCI_Connection_Complete_Event
):
if event.link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
# Pass this on to the synchronous connection handler
forwarded_event = hci.HCI_Synchronous_Connection_Complete_Event(
status=event.status,
connection_handle=event.connection_handle,
bd_addr=event.bd_addr,
link_type=event.link_type,
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
self.on_hci_synchronous_connection_complete_event(forwarded_event)
return
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
f'### BR/EDR ACL CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr} '
f'{event.link_type.name}'
)
connection = self.connections.get(event.connection_handle)
@@ -1580,6 +1600,9 @@ class Host(utils.EventEmitter):
event.bd_addr,
event.connection_handle,
event.link_type,
event.rx_packet_length,
event.tx_packet_length,
event.air_mode,
)
else:
logger.debug(f'### SCO CONNECTION FAILED: {event.status}')
@@ -1658,19 +1681,6 @@ class Host(utils.EventEmitter):
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_read_remote_supported_features_complete_event(
self, event: hci.HCI_Read_Remote_Supported_Features_Complete_Event
) -> None:
# Notify the client
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.lmp_features, 'little'),
0, # page number
0, # max page number
)
def on_hci_encryption_change_v2_event(
self, event: hci.HCI_Encryption_Change_V2_Event
):
@@ -1827,18 +1837,6 @@ class Host(utils.EventEmitter):
rssi,
)
def on_hci_read_remote_extended_features_complete_event(
self, event: hci.HCI_Read_Remote_Extended_Features_Complete_Event
):
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.extended_lmp_features, 'little'),
event.page_number,
event.maximum_page_number,
)
def on_hci_extended_inquiry_result_event(
self, event: hci.HCI_Extended_Inquiry_Result_Event
):

View File

@@ -27,7 +27,6 @@ import dataclasses
import json
import logging
import os
import pathlib
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
@@ -249,26 +248,29 @@ class JsonKeyStore(KeyStore):
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(
self, namespace: str | None = None, filename: str | None = None
) -> None:
self.namespace = namespace or self.DEFAULT_NAMESPACE
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
if filename:
self.filename = pathlib.Path(filename).resolve()
self.directory_name = self.filename.parent
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else:
import platformdirs # Deferred import
self.filename = filename
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR)
self.directory_name = base_dir / self.KEYS_DIR
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
safe_name = base_name.lower().replace(':', '-').replace('/', '-')
self.filename = self.directory_name / f"{safe_name}.json"
logger.debug('JSON keystore: %s', self.filename)
logger.debug(f'JSON keystore: {self.filename}')
@classmethod
def from_device(
@@ -291,9 +293,7 @@ class JsonKeyStore(KeyStore):
return cls(namespace, filename)
async def load(
self,
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
async def load(self):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
@@ -312,17 +312,17 @@ class JsonKeyStore(KeyStore):
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map: dict[str, dict[str, Any]] = {}
key_map = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None:
async def save(self, db):
# Create the directory if it doesn't exist
if not self.directory_name.exists():
self.directory_name.mkdir(parents=True, exist_ok=True)
if not os.path.exists(self.directory_name):
os.makedirs(self.directory_name, exist_ok=True)
# Save to a temporary file
temp_filename = self.filename.with_name(self.filename.name + ".tmp")
temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)
@@ -334,16 +334,16 @@ class JsonKeyStore(KeyStore):
del key_map[name]
await self.save(db)
async def update(self, name: str, keys: PairingKeys) -> None:
async def update(self, name, keys):
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self) -> list[tuple[str, PairingKeys]]:
async def get_all(self):
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self) -> None:
async def delete_all(self):
db, key_map = await self.load()
key_map.clear()
await self.save(db)

View File

@@ -198,24 +198,3 @@ class CisTerminateInd(ControlPdu):
cig_id: int
cis_id: int
error_code: int
@dataclasses.dataclass
class FeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
feature_set: bytes
@dataclasses.dataclass
class FeatureRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
feature_set: bytes
@dataclasses.dataclass
class PeripheralFeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
feature_set: bytes

View File

@@ -322,38 +322,3 @@ class LmpNameRes(Packet):
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))
@Packet.subclass
@dataclass
class LmpFeaturesReq(Packet):
opcode = Opcode.LMP_FEATURES_REQ
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesRes(Packet):
opcode = Opcode.LMP_FEATURES_RES
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesReqExt(Packet):
opcode = Opcode.LMP_FEATURES_REQ_EXT
features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesResExt(Packet):
opcode = Opcode.LMP_FEATURES_RES_EXT
features_page: int = field(metadata=hci.metadata(1))
max_features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))

View File

@@ -21,9 +21,18 @@ import enum
import secrets
from dataclasses import dataclass
from bumble import hci, smp
from bumble import hci
from bumble.core import AdvertisingData, LeRole
from bumble.smp import (
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
@@ -87,11 +96,11 @@ class PairingDelegate:
# These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants.
class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
# Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
@@ -102,10 +111,10 @@ class PairingDelegate:
# Key Distribution [LE only]
class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY

View File

@@ -664,44 +664,46 @@ class AudioStreamControlService(gatt.TemplateService):
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
match operation:
case ASE_Config_Codec():
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Config_QOS():
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Enable() | ASE_Update_Metadata():
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case (
ASE_Receiver_Start_Ready()
| ASE_Disable()
| ASE_Receiver_Stop_Ready()
| ASE_Release()
if isinstance(operation, ASE_Config_Codec):
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, ASE_Config_QOS):
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(
operation,
(
ASE_Receiver_Start_Ready,
ASE_Disable,
ASE_Receiver_Stop_Ready,
ASE_Release,
),
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]

View File

@@ -333,18 +333,17 @@ class CodecSpecificCapabilities:
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
match type:
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
case CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment

View File

@@ -55,15 +55,14 @@ class GenericAccessService(TemplateService):
def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
):
match appearance:
case int():
appearance_int = appearance
case tuple():
appearance_int = (appearance[0] << 6) | appearance[1]
case Appearance():
appearance_int = int(appearance)
case _:
raise TypeError()
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,

View File

@@ -110,7 +110,7 @@ RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_INITIAL_CREDITS = 7
RFCOMM_DEFAULT_MAX_CREDITS = 32
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 1000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30

View File

@@ -21,12 +21,11 @@ import asyncio
import logging
import struct
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar
from typing import TYPE_CHECKING, NewType
from typing_extensions import Self
from bumble import core, hci, l2cap, utils
from bumble import core, l2cap
from bumble.colors import color
from bumble.core import (
InvalidArgumentError,
@@ -34,6 +33,7 @@ from bumble.core import (
InvalidStateError,
ProtocolError,
)
from bumble.hci import HCI_Object, key_with_value, name_or_number
if TYPE_CHECKING:
from bumble.device import Connection, Device
@@ -54,22 +54,39 @@ SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing
SDP_PSM = 0x0001
class PduId(hci.SpecableEnum):
SDP_ERROR_RESPONSE = 0x01
SDP_SERVICE_SEARCH_REQUEST = 0x02
SDP_SERVICE_SEARCH_RESPONSE = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07
SDP_ERROR_RESPONSE = 0x01
SDP_SERVICE_SEARCH_REQUEST = 0x02
SDP_SERVICE_SEARCH_RESPONSE = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07
class ErrorCode(hci.SpecableEnum):
INVALID_SDP_VERSION = 0x0001
INVALID_SERVICE_RECORD_HANDLE = 0x0002
INVALID_REQUEST_SYNTAX = 0x0003
INVALID_PDU_SIZE = 0x0004
INVALID_CONTINUATION_STATE = 0x0005
INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST = 0x0006
SDP_PDU_NAMES = {
SDP_ERROR_RESPONSE: 'SDP_ERROR_RESPONSE',
SDP_SERVICE_SEARCH_REQUEST: 'SDP_SERVICE_SEARCH_REQUEST',
SDP_SERVICE_SEARCH_RESPONSE: 'SDP_SERVICE_SEARCH_RESPONSE',
SDP_SERVICE_ATTRIBUTE_REQUEST: 'SDP_SERVICE_ATTRIBUTE_REQUEST',
SDP_SERVICE_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_ATTRIBUTE_RESPONSE',
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: 'SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST',
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE'
}
SDP_INVALID_SDP_VERSION_ERROR = 0x0001
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR = 0x0002
SDP_INVALID_REQUEST_SYNTAX_ERROR = 0x0003
SDP_INVALID_PDU_SIZE_ERROR = 0x0004
SDP_INVALID_CONTINUATION_STATE_ERROR = 0x0005
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR = 0x0006
SDP_ERROR_NAMES = {
SDP_INVALID_SDP_VERSION_ERROR: 'SDP_INVALID_SDP_VERSION_ERROR',
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR: 'SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR',
SDP_INVALID_REQUEST_SYNTAX_ERROR: 'SDP_INVALID_REQUEST_SYNTAX_ERROR',
SDP_INVALID_PDU_SIZE_ERROR: 'SDP_INVALID_PDU_SIZE_ERROR',
SDP_INVALID_CONTINUATION_STATE_ERROR: 'SDP_INVALID_CONTINUATION_STATE_ERROR',
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR: 'SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR'
}
SDP_SERVICE_NAME_ATTRIBUTE_ID_OFFSET = 0x0000
SDP_SERVICE_DESCRIPTION_ATTRIBUTE_ID_OFFSET = 0x0001
@@ -124,31 +141,30 @@ SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF)
# -----------------------------------------------------------------------------
@dataclass
class DataElement:
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
class Type(utils.OpenIntEnum):
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
TYPE_NAMES = {
NIL: 'NIL',
UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL',
}
NIL = Type.NIL
UNSIGNED_INTEGER = Type.UNSIGNED_INTEGER
SIGNED_INTEGER = Type.SIGNED_INTEGER
UUID = Type.UUID
TEXT_STRING = Type.TEXT_STRING
BOOLEAN = Type.BOOLEAN
SEQUENCE = Type.SEQUENCE
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
@@ -174,18 +190,14 @@ class DataElement:
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
def __post_init__(self) -> None:
def __init__(self, element_type, value, value_size=None):
self.type = element_type
self.value = value
self.value_size = value_size
# Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
self._bytes: bytes | None = None
if self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
if self.value_size is None:
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise InvalidArgumentError(
'integer types must have a value size specified'
)
@@ -325,7 +337,7 @@ class DataElement:
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
constructor = DataElement.type_constructors.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
@@ -336,15 +348,15 @@ class DataElement:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
# Return early if we have a cache
if self._bytes:
return self._bytes
if self.bytes:
return self.bytes
if self.type == DataElement.NIL:
data = b''
@@ -431,12 +443,12 @@ class DataElement:
else:
raise RuntimeError("internal error - self.type not supported")
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self.bytes
def to_string(self, pretty=False, indentation=0):
prefix = ' ' * indentation
type_name = self.type.name
type_name = name_or_number(self.TYPE_NAMES, self.type)
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
@@ -464,10 +476,10 @@ class DataElement:
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
id: int
value: DataElement
def __init__(self, attribute_id: int, value: DataElement) -> None:
self.id = attribute_id
self.value = value
@staticmethod
def list_from_data_elements(
@@ -498,7 +510,7 @@ class ServiceAttribute:
@staticmethod
def id_name(id_code):
return hci.name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
@@ -528,223 +540,239 @@ class ServiceAttribute:
# -----------------------------------------------------------------------------
def _parse_service_record_handle_list(
data: bytes, offset: int
) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset)[0]
offset += 2
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
def _serialize_service_record_handle_list(
handles: list[int],
) -> bytes:
return struct.pack('>H', len(handles)) + b''.join(
struct.pack('>I', handle) for handle in handles
)
def _parse_bytes_preceded_by_length(data: bytes, offset: int) -> tuple[int, bytes]:
length = struct.unpack_from('>H', data, offset)[0]
offset += 2
return offset + length, data[offset : offset + length]
def _serialize_bytes_preceded_by_length(data: bytes) -> bytes:
return struct.pack('>H', len(data)) + data
_SERVICE_RECORD_HANDLE_LIST_METADATA = hci.metadata(
{
'parser': _parse_service_record_handle_list,
'serializer': _serialize_service_record_handle_list,
}
)
_BYTES_PRECEDED_BY_LENGTH_METADATA = hci.metadata(
{
'parser': _parse_bytes_preceded_by_length,
'serializer': _serialize_bytes_preceded_by_length,
}
)
# -----------------------------------------------------------------------------
@dataclass
class SDP_PDU:
'''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
RESPONSE_PDU_IDS = {
PduId.SDP_SERVICE_SEARCH_REQUEST: PduId.SDP_SERVICE_SEARCH_RESPONSE,
PduId.SDP_SERVICE_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE,
PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
}
subclasses: ClassVar[dict[int, type[SDP_PDU]]] = {}
pdu_id: ClassVar[PduId]
fields: ClassVar[hci.Fields]
sdp_pdu_classes: dict[int, type[SDP_PDU]] = {}
name = None
pdu_id = 0
transaction_id: int
_payload: bytes | None = field(init=False, repr=False, default=None)
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
@staticmethod
def from_bytes(pdu):
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
raise InvalidPacketError(f"Unknown PDU type {pdu_id}")
instance = subclass(
transaction_id=transaction_id,
**hci.HCI_Object.dict_from_bytes(pdu, 5, subclass.fields),
)
instance._payload = pdu
return instance
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None:
instance = SDP_PDU(pdu)
instance.name = SDP_PDU.pdu_name(pdu_id)
instance.pdu_id = pdu_id
instance.transaction_id = transaction_id
return instance
self = cls.__new__(cls)
SDP_PDU.__init__(self, pdu, transaction_id)
if hasattr(self, 'fields'):
self.init_from_bytes(pdu, 5)
return self
_PDU = TypeVar('_PDU', bound='SDP_PDU')
@staticmethod
def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int
) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
@classmethod
def subclass(cls, subclass: type[_PDU]) -> type[_PDU]:
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
cls.subclasses[subclass.pdu_id] = subclass
return subclass
@staticmethod
def parse_bytes_preceded_by_length(data, offset):
length = struct.unpack_from('>H', data, offset - 2)[0]
return offset + length, data[offset : offset + length]
@staticmethod
def error_name(error_code):
return name_or_number(SDP_ERROR_NAMES, error_code)
@staticmethod
def pdu_name(code):
return name_or_number(SDP_PDU_NAMES, code)
@staticmethod
def subclass(fields):
def inner(cls):
name = cls.__name__
# add a _ character before every uppercase letter, except the SDP_ prefix
location = len(name) - 1
while location > 4:
if not name[location].isupper():
location -= 1
continue
name = name[:location] + '_' + name[location:]
location -= 1
cls.name = name.upper()
cls.pdu_id = key_with_value(SDP_PDU_NAMES, cls.name)
if cls.pdu_id is None:
raise KeyError(f'PDU name {cls.name} not found in SDP_PDU_NAMES')
cls.fields = fields
# Register a factory for this class
SDP_PDU.sdp_pdu_classes[cls.pdu_id] = cls
return cls
return inner
def __init__(self, pdu=None, transaction_id=0, **kwargs):
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu
self.transaction_id = transaction_id
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self):
if self._payload is None:
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@property
def name(self) -> str:
return self.pdu_id.name
return self.pdu
def __str__(self):
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ')
elif len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
class SDP_ErrorResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
'''
pdu_id = PduId.SDP_ERROR_RESPONSE
error_code: ErrorCode = field(metadata=ErrorCode.type_metadata(2))
error_code: int
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
'''
pdu_id = PduId.SDP_SERVICE_SEARCH_REQUEST
service_search_pattern: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
maximum_service_record_count: int = field(metadata=hci.metadata('>2'))
continuation_state: bytes = field(metadata=hci.metadata('*'))
service_search_pattern: DataElement
maximum_service_record_count: int
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
(
'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
'''
pdu_id = PduId.SDP_SERVICE_SEARCH_RESPONSE
total_service_record_count: int = field(metadata=hci.metadata('>2'))
service_record_handle_list: Sequence[int] = field(
metadata=_SERVICE_RECORD_HANDLE_LIST_METADATA
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
service_record_handle_list: list[int]
total_service_record_count: int
current_service_record_count: int
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
'''
pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_REQUEST
service_record_handle: int = field(metadata=hci.metadata('>4'))
maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2'))
attribute_id_list: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
service_record_handle: int
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
'''
pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE
attribute_list: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA)
continuation_state: bytes = field(metadata=hci.metadata('*'))
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
'''
pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST
service_search_pattern: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2'))
attribute_id_list: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
service_search_pattern: DataElement
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass
@dataclass
@SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
'''
pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE
attribute_lists: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA)
continuation_state: bytes = field(metadata=hci.metadata('*'))
attribute_lists_byte_count: int
attribute_lists: bytes
continuation_state: bytes
# -----------------------------------------------------------------------------
@@ -845,7 +873,7 @@ class Client:
)
# Request and accumulate until there's no more continuation
service_record_handle_list: list[int] = []
service_record_handle_list = []
continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0:
@@ -1063,7 +1091,7 @@ class Server:
logger.exception(color('failed to parse SDP Request PDU', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id=0, error_code=ErrorCode.INVALID_REQUEST_SYNTAX
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
)
)
@@ -1080,7 +1108,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=ErrorCode.INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST,
error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
)
)
else:
@@ -1088,7 +1116,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=ErrorCode.INVALID_REQUEST_SYNTAX,
error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
)
)
@@ -1106,7 +1134,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=transaction_id,
error_code=ErrorCode.INVALID_CONTINUATION_STATE,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return None
@@ -1200,11 +1228,15 @@ class Server:
if service_record_handles_remaining
else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
)
self.send_response(
SDP_ServiceSearchResponse(
transaction_id=request.transaction_id,
total_service_record_count=total_service_record_count,
service_record_handle_list=service_record_handles,
current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state,
)
)
@@ -1227,7 +1259,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=ErrorCode.INVALID_SERVICE_RECORD_HANDLE,
error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
)
)
return
@@ -1252,6 +1284,7 @@ class Server:
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list_response,
continuation_state=continuation_state,
)
@@ -1298,6 +1331,7 @@ class Server:
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists_response,
continuation_state=continuation_state,
)

View File

@@ -31,13 +31,14 @@ from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
from bumble import crypto, hci, utils
from bumble import crypto, utils
from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
PhysicalTransport,
ProtocolError,
name_or_number,
)
from bumble.hci import (
Address,
@@ -45,6 +46,7 @@ from bumble.hci import (
HCI_LE_Enable_Encryption_Command,
HCI_Object,
Role,
key_with_value,
metadata,
)
from bumble.keys import PairingKeys
@@ -69,125 +71,115 @@ logger = logging.getLogger(__name__)
SMP_CID = 0x06
SMP_BR_CID = 0x07
class CommandCode(hci.SpecableEnum):
PAIRING_REQUEST = 0x01
PAIRING_RESPONSE = 0x02
PAIRING_CONFIRM = 0x03
PAIRING_RANDOM = 0x04
PAIRING_FAILED = 0x05
ENCRYPTION_INFORMATION = 0x06
MASTER_IDENTIFICATION = 0x07
IDENTITY_INFORMATION = 0x08
IDENTITY_ADDRESS_INFORMATION = 0x09
SIGNING_INFORMATION = 0x0A
SECURITY_REQUEST = 0x0B
PAIRING_PUBLIC_KEY = 0x0C
PAIRING_DHKEY_CHECK = 0x0D
PAIRING_KEYPRESS_NOTIFICATION = 0x0E
SMP_PAIRING_REQUEST_COMMAND = 0x01
SMP_PAIRING_RESPONSE_COMMAND = 0x02
SMP_PAIRING_CONFIRM_COMMAND = 0x03
SMP_PAIRING_RANDOM_COMMAND = 0x04
SMP_PAIRING_FAILED_COMMAND = 0x05
SMP_ENCRYPTION_INFORMATION_COMMAND = 0x06
SMP_MASTER_IDENTIFICATION_COMMAND = 0x07
SMP_IDENTITY_INFORMATION_COMMAND = 0x08
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND = 0x09
SMP_SIGNING_INFORMATION_COMMAND = 0x0A
SMP_SECURITY_REQUEST_COMMAND = 0x0B
SMP_PAIRING_PUBLIC_KEY_COMMAND = 0x0C
SMP_PAIRING_DHKEY_CHECK_COMMAND = 0x0D
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND = 0x0E
SMP_COMMAND_NAMES = {
SMP_PAIRING_REQUEST_COMMAND: 'SMP_PAIRING_REQUEST_COMMAND',
SMP_PAIRING_RESPONSE_COMMAND: 'SMP_PAIRING_RESPONSE_COMMAND',
SMP_PAIRING_CONFIRM_COMMAND: 'SMP_PAIRING_CONFIRM_COMMAND',
SMP_PAIRING_RANDOM_COMMAND: 'SMP_PAIRING_RANDOM_COMMAND',
SMP_PAIRING_FAILED_COMMAND: 'SMP_PAIRING_FAILED_COMMAND',
SMP_ENCRYPTION_INFORMATION_COMMAND: 'SMP_ENCRYPTION_INFORMATION_COMMAND',
SMP_MASTER_IDENTIFICATION_COMMAND: 'SMP_MASTER_IDENTIFICATION_COMMAND',
SMP_IDENTITY_INFORMATION_COMMAND: 'SMP_IDENTITY_INFORMATION_COMMAND',
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND: 'SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND',
SMP_SIGNING_INFORMATION_COMMAND: 'SMP_SIGNING_INFORMATION_COMMAND',
SMP_SECURITY_REQUEST_COMMAND: 'SMP_SECURITY_REQUEST_COMMAND',
SMP_PAIRING_PUBLIC_KEY_COMMAND: 'SMP_PAIRING_PUBLIC_KEY_COMMAND',
SMP_PAIRING_DHKEY_CHECK_COMMAND: 'SMP_PAIRING_DHKEY_CHECK_COMMAND',
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND: 'SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND'
}
class IoCapability(hci.SpecableEnum):
DISPLAY_ONLY = 0x00
DISPLAY_YES_NO = 0x01
KEYBOARD_ONLY = 0x02
NO_INPUT_NO_OUTPUT = 0x03
KEYBOARD_DISPLAY = 0x04
SMP_DISPLAY_ONLY_IO_CAPABILITY = 0x00
SMP_DISPLAY_YES_NO_IO_CAPABILITY = 0x01
SMP_KEYBOARD_ONLY_IO_CAPABILITY = 0x02
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = 0x03
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = 0x04
SMP_DISPLAY_ONLY_IO_CAPABILITY = IoCapability.DISPLAY_ONLY
SMP_DISPLAY_YES_NO_IO_CAPABILITY = IoCapability.DISPLAY_YES_NO
SMP_KEYBOARD_ONLY_IO_CAPABILITY = IoCapability.KEYBOARD_ONLY
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = IoCapability.NO_INPUT_NO_OUTPUT
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = IoCapability.KEYBOARD_DISPLAY
SMP_IO_CAPABILITY_NAMES = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: 'SMP_DISPLAY_ONLY_IO_CAPABILITY',
SMP_DISPLAY_YES_NO_IO_CAPABILITY: 'SMP_DISPLAY_YES_NO_IO_CAPABILITY',
SMP_KEYBOARD_ONLY_IO_CAPABILITY: 'SMP_KEYBOARD_ONLY_IO_CAPABILITY',
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: 'SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY',
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: 'SMP_KEYBOARD_DISPLAY_IO_CAPABILITY'
}
class ErrorCode(hci.SpecableEnum):
PASSKEY_ENTRY_FAILED = 0x01
OOB_NOT_AVAILABLE = 0x02
AUTHENTICATION_REQUIREMENTS = 0x03
CONFIRM_VALUE_FAILED = 0x04
PAIRING_NOT_SUPPORTED = 0x05
ENCRYPTION_KEY_SIZE = 0x06
COMMAND_NOT_SUPPORTED = 0x07
UNSPECIFIED_REASON = 0x08
REPEATED_ATTEMPTS = 0x09
INVALID_PARAMETERS = 0x0A
DHKEY_CHECK_FAILED = 0x0B
NUMERIC_COMPARISON_FAILED = 0x0C
BD_EDR_PAIRING_IN_PROGRESS = 0x0D
CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED = 0x0E
SMP_PASSKEY_ENTRY_FAILED_ERROR = 0x01
SMP_OOB_NOT_AVAILABLE_ERROR = 0x02
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = 0x03
SMP_CONFIRM_VALUE_FAILED_ERROR = 0x04
SMP_PAIRING_NOT_SUPPORTED_ERROR = 0x05
SMP_ENCRYPTION_KEY_SIZE_ERROR = 0x06
SMP_COMMAND_NOT_SUPPORTED_ERROR = 0x07
SMP_UNSPECIFIED_REASON_ERROR = 0x08
SMP_REPEATED_ATTEMPTS_ERROR = 0x09
SMP_INVALID_PARAMETERS_ERROR = 0x0A
SMP_DHKEY_CHECK_FAILED_ERROR = 0x0B
SMP_NUMERIC_COMPARISON_FAILED_ERROR = 0x0C
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = 0x0D
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = 0x0E
SMP_PASSKEY_ENTRY_FAILED_ERROR = ErrorCode.PASSKEY_ENTRY_FAILED
SMP_OOB_NOT_AVAILABLE_ERROR = ErrorCode.OOB_NOT_AVAILABLE
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = ErrorCode.AUTHENTICATION_REQUIREMENTS
SMP_CONFIRM_VALUE_FAILED_ERROR = ErrorCode.CONFIRM_VALUE_FAILED
SMP_PAIRING_NOT_SUPPORTED_ERROR = ErrorCode.PAIRING_NOT_SUPPORTED
SMP_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.ENCRYPTION_KEY_SIZE
SMP_COMMAND_NOT_SUPPORTED_ERROR = ErrorCode.COMMAND_NOT_SUPPORTED
SMP_UNSPECIFIED_REASON_ERROR = ErrorCode.UNSPECIFIED_REASON
SMP_REPEATED_ATTEMPTS_ERROR = ErrorCode.REPEATED_ATTEMPTS
SMP_INVALID_PARAMETERS_ERROR = ErrorCode.INVALID_PARAMETERS
SMP_DHKEY_CHECK_FAILED_ERROR = ErrorCode.DHKEY_CHECK_FAILED
SMP_NUMERIC_COMPARISON_FAILED_ERROR = ErrorCode.NUMERIC_COMPARISON_FAILED
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = ErrorCode.BD_EDR_PAIRING_IN_PROGRESS
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
SMP_ERROR_NAMES = {
SMP_PASSKEY_ENTRY_FAILED_ERROR: 'SMP_PASSKEY_ENTRY_FAILED_ERROR',
SMP_OOB_NOT_AVAILABLE_ERROR: 'SMP_OOB_NOT_AVAILABLE_ERROR',
SMP_AUTHENTICATION_REQUIREMENTS_ERROR: 'SMP_AUTHENTICATION_REQUIREMENTS_ERROR',
SMP_CONFIRM_VALUE_FAILED_ERROR: 'SMP_CONFIRM_VALUE_FAILED_ERROR',
SMP_PAIRING_NOT_SUPPORTED_ERROR: 'SMP_PAIRING_NOT_SUPPORTED_ERROR',
SMP_ENCRYPTION_KEY_SIZE_ERROR: 'SMP_ENCRYPTION_KEY_SIZE_ERROR',
SMP_COMMAND_NOT_SUPPORTED_ERROR: 'SMP_COMMAND_NOT_SUPPORTED_ERROR',
SMP_UNSPECIFIED_REASON_ERROR: 'SMP_UNSPECIFIED_REASON_ERROR',
SMP_REPEATED_ATTEMPTS_ERROR: 'SMP_REPEATED_ATTEMPTS_ERROR',
SMP_INVALID_PARAMETERS_ERROR: 'SMP_INVALID_PARAMETERS_ERROR',
SMP_DHKEY_CHECK_FAILED_ERROR: 'SMP_DHKEY_CHECK_FAILED_ERROR',
SMP_NUMERIC_COMPARISON_FAILED_ERROR: 'SMP_NUMERIC_COMPARISON_FAILED_ERROR',
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR: 'SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR',
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR: 'SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR'
}
class KeypressNotificationType(hci.SpecableEnum):
PASSKEY_ENTRY_STARTED = 0
PASSKEY_DIGIT_ENTERED = 1
PASSKEY_DIGIT_ERASED = 2
PASSKEY_CLEARED = 3
PASSKEY_ENTRY_COMPLETED = 4
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE = 0
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE = 1
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE = 2
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE = 3
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE = 4
SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES = {
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE'
}
# Bit flags for key distribution/generation
class KeyDistribution(hci.SpecableFlag):
ENC_KEY = 0b0001
ID_KEY = 0b0010
SIGN_KEY = 0b0100
LINK_KEY = 0b1000
SMP_ENC_KEY_DISTRIBUTION_FLAG = 0b0001
SMP_ID_KEY_DISTRIBUTION_FLAG = 0b0010
SMP_SIGN_KEY_DISTRIBUTION_FLAG = 0b0100
SMP_LINK_KEY_DISTRIBUTION_FLAG = 0b1000
# AuthReq fields
class AuthReq(hci.SpecableFlag):
BONDING = 0b00000001
MITM = 0b00000100
SC = 0b00001000
KEYPRESS = 0b00010000
CT2 = 0b00100000
@classmethod
def from_booleans(
cls,
bonding: bool = False,
sc: bool = False,
mitm: bool = False,
keypress: bool = False,
ct2: bool = False,
) -> AuthReq:
auth_req = AuthReq(0)
if bonding:
auth_req |= AuthReq.BONDING
if sc:
auth_req |= AuthReq.SC
if mitm:
auth_req |= AuthReq.MITM
if keypress:
auth_req |= AuthReq.KEYPRESS
if ct2:
auth_req |= AuthReq.CT2
return auth_req
SMP_BONDING_AUTHREQ = 0b00000001
SMP_MITM_AUTHREQ = 0b00000100
SMP_SC_AUTHREQ = 0b00001000
SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000
# Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# Diffie-Hellman private / public key pair in Debug Mode (Core - Vol. 3, Part H)
SMP_DEBUG_KEY_PRIVATE = bytes.fromhex(
'3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd'
)
SMP_DEBUG_KEY_PUBLIC_X = bytes.fromhex(
'20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6'
)
SMP_DEBUG_KEY_PUBLIC_Y= bytes.fromhex(
'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b'
)
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
@@ -196,6 +188,8 @@ SMP_DEBUG_KEY_PUBLIC_Y= bytes.fromhex(
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code)
# -----------------------------------------------------------------------------
@@ -207,20 +201,20 @@ class SMP_Command:
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
'''
smp_classes: ClassVar[dict[CommandCode, type[SMP_Command]]] = {}
smp_classes: ClassVar[dict[int, type[SMP_Command]]] = {}
fields: ClassVar[Fields]
code: CommandCode = field(default=CommandCode(0), init=False)
code: int = field(default=0, init=False)
name: str = field(default='', init=False)
_payload: bytes | None = field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
code = CommandCode(pdu[0])
code = pdu[0]
subclass = SMP_Command.smp_classes.get(code)
if subclass is None:
instance = SMP_Command()
instance.name = code.name
instance.name = SMP_Command.command_name(code)
instance.code = code
instance.payload = pdu
return instance
@@ -228,14 +222,59 @@ class SMP_Command:
instance.payload = pdu[1:]
return instance
@staticmethod
def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod
def auth_req_str(value: int) -> str:
bonding_flags = value & 3
mitm = (value >> 2) & 1
sc = (value >> 3) & 1
keypress = (value >> 4) & 1
ct2 = (value >> 5) & 1
return (
f'bonding_flags={bonding_flags}, '
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
)
@staticmethod
def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod
def key_distribution_str(value: int) -> str:
key_types: list[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
key_types.append('ID')
if value & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
key_types.append('SIGN')
if value & SMP_LINK_KEY_DISTRIBUTION_FLAG:
key_types.append('LINK')
return ','.join(key_types)
@staticmethod
def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
_Command = TypeVar("_Command", bound="SMP_Command")
@classmethod
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
subclass.name = subclass.__name__.upper()
subclass.code = key_with_value(SMP_COMMAND_NAMES, subclass.name)
if subclass.code is None:
raise KeyError(
f'Command name {subclass.name} not found in SMP_COMMAND_NAMES'
)
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
# Register a factory for this class
SMP_Command.smp_classes[subclass.code] = subclass
return subclass
@property
@@ -269,17 +308,19 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
'''
code = CommandCode.PAIRING_REQUEST
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
)
@@ -291,17 +332,19 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
'''
code = CommandCode.PAIRING_RESPONSE
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
)
@@ -313,8 +356,6 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
'''
code = CommandCode.PAIRING_CONFIRM
confirm_value: bytes = field(metadata=metadata(16))
@@ -326,8 +367,6 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
'''
code = CommandCode.PAIRING_RANDOM
random_value: bytes = field(metadata=metadata(16))
@@ -339,9 +378,7 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
'''
code = CommandCode.PAIRING_FAILED
reason: ErrorCode = field(metadata=ErrorCode.type_metadata(1))
reason: int = field(metadata=metadata({'size': 1, 'mapper': error_name}))
# -----------------------------------------------------------------------------
@@ -352,8 +389,6 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
'''
code = CommandCode.PAIRING_PUBLIC_KEY
public_key_x: bytes = field(metadata=metadata(32))
public_key_y: bytes = field(metadata=metadata(32))
@@ -366,8 +401,6 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
'''
code = CommandCode.PAIRING_DHKEY_CHECK
dhkey_check: bytes = field(metadata=metadata(16))
@@ -379,10 +412,10 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
'''
code = CommandCode.PAIRING_KEYPRESS_NOTIFICATION
notification_type: KeypressNotificationType = field(
metadata=KeypressNotificationType.type_metadata(1)
notification_type: int = field(
metadata=metadata(
{'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}
)
)
@@ -394,8 +427,6 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
'''
code = CommandCode.ENCRYPTION_INFORMATION
long_term_key: bytes = field(metadata=metadata(16))
@@ -407,8 +438,6 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
'''
code = CommandCode.MASTER_IDENTIFICATION
ediv: int = field(metadata=metadata(2))
rand: bytes = field(metadata=metadata(8))
@@ -421,8 +450,6 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
'''
code = CommandCode.IDENTITY_INFORMATION
identity_resolving_key: bytes = field(metadata=metadata(16))
@@ -434,8 +461,6 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
'''
code = CommandCode.IDENTITY_ADDRESS_INFORMATION
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
@@ -448,8 +473,6 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
'''
code = CommandCode.SIGNING_INFORMATION
signature_key: bytes = field(metadata=metadata(16))
@@ -461,9 +484,25 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
'''
code = CommandCode.SECURITY_REQUEST
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
# -----------------------------------------------------------------------------
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0
if bonding:
value |= SMP_BONDING_AUTHREQ
if mitm:
value |= SMP_MITM_AUTHREQ
if sc:
value |= SMP_SC_AUTHREQ
if keypress:
value |= SMP_KEYPRESS_AUTHREQ
if ct2:
value |= SMP_CT2_AUTHREQ
return value
# -----------------------------------------------------------------------------
@@ -637,8 +676,8 @@ class Session:
self.ltk_rand = bytes(8)
self.link_key: bytes | None = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: KeyDistribution = KeyDistribution(0)
self.responder_key_distribution: KeyDistribution = KeyDistribution(0)
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: bytes | None = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
@@ -689,10 +728,10 @@ class Session:
)
# Key Distribution (default values before negotiation)
self.initiator_key_distribution = KeyDistribution(
self.initiator_key_distribution = (
pairing_config.delegate.local_initiator_key_distribution
)
self.responder_key_distribution = KeyDistribution(
self.responder_key_distribution = (
pairing_config.delegate.local_responder_key_distribution
)
@@ -704,7 +743,7 @@ class Session:
self.ct2: bool = False
# I/O Capabilities
self.io_capability = IoCapability(pairing_config.delegate.io_capability)
self.io_capability = pairing_config.delegate.io_capability
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
@@ -783,14 +822,8 @@ class Session:
return self.nx[0 if self.is_responder else 1]
@property
def auth_req(self) -> AuthReq:
return AuthReq.from_booleans(
bonding=self.bonding,
sc=self.sc,
mitm=self.mitm,
keypress=self.keypress,
ct2=self.ct2,
)
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
if not self.sc and not self.completed:
@@ -810,7 +843,7 @@ class Session:
if self.connection.transport == PhysicalTransport.BR_EDR:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return
if (not self.mitm) and (auth_req & AuthReq.MITM == 0):
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = PairingMethod.JUST_WORKS
return
@@ -828,7 +861,7 @@ class Session:
self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(
self, expected: bytes, received: bytes, error: ErrorCode
self, expected: bytes, received: bytes, error: int
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received:
@@ -848,7 +881,7 @@ class Session:
except Exception:
logger.exception('exception while confirm')
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.connection.cancel_on_disconnection(prompt())
@@ -867,7 +900,7 @@ class Session:
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.connection.cancel_on_disconnection(prompt())
@@ -878,13 +911,13 @@ class Session:
passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
return
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.connection.cancel_on_disconnection(prompt())
@@ -939,7 +972,7 @@ class Session:
def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error: ErrorCode) -> None:
def send_pairing_failed(self, error: int) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error)
@@ -1111,7 +1144,7 @@ class Session:
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
@@ -1122,14 +1155,14 @@ class Session:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.initiator_key_distribution & KeyDistribution.ENC_KEY
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & KeyDistribution.ENC_KEY:
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1140,7 +1173,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.initiator_key_distribution & KeyDistribution.ID_KEY:
if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1150,25 +1183,25 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.initiator_key_distribution & KeyDistribution.SIGN_KEY:
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & KeyDistribution.LINK_KEY:
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
else:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.responder_key_distribution & KeyDistribution.ENC_KEY
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
if self.responder_key_distribution & KeyDistribution.ENC_KEY:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1179,7 +1212,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.responder_key_distribution & KeyDistribution.ID_KEY:
if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1189,30 +1222,30 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.responder_key_distribution & KeyDistribution.SIGN_KEY:
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & KeyDistribution.LINK_KEY:
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = []
if not self.sc and self.connection.transport == PhysicalTransport.LE:
if key_distribution_flags & KeyDistribution.ENC_KEY != 0:
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
self.peer_expected_distributions.append(
SMP_Encryption_Information_Command
)
self.peer_expected_distributions.append(
SMP_Master_Identification_Command
)
if key_distribution_flags & KeyDistribution.ID_KEY != 0:
if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0:
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
self.peer_expected_distributions.append(
SMP_Identity_Address_Information_Command
)
if key_distribution_flags & KeyDistribution.SIGN_KEY != 0:
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
logger.debug(
'expecting distributions: '
@@ -1225,7 +1258,7 @@ class Session:
logger.warning(
color('received key distribution on a non-encrypted connection', 'red')
)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
return
# Check that this command class is expected
@@ -1245,7 +1278,7 @@ class Session:
'red',
)
)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
async def pair(self) -> None:
# Start pairing as an initiator
@@ -1356,56 +1389,34 @@ class Session:
)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: ErrorCode) -> None:
logger.warning('pairing failure (%s)', reason.name)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
if self.completed:
return
self.completed = True
error = ProtocolError(reason, 'smp', reason.name)
error = ProtocolError(reason, 'smp', error_name(reason))
if self.pairing_result is not None and not self.pairing_result.done():
self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command: SMP_Command) -> None:
try:
match command:
case SMP_Pairing_Request_Command():
self.on_smp_pairing_request_command(command)
case SMP_Pairing_Response_Command():
self.on_smp_pairing_response_command(command)
case SMP_Pairing_Confirm_Command():
self.on_smp_pairing_confirm_command(command)
case SMP_Pairing_Random_Command():
self.on_smp_pairing_random_command(command)
case SMP_Pairing_Failed_Command():
self.on_smp_pairing_failed_command(command)
case SMP_Encryption_Information_Command():
self.on_smp_encryption_information_command(command)
case SMP_Master_Identification_Command():
self.on_smp_master_identification_command(command)
case SMP_Identity_Information_Command():
self.on_smp_identity_information_command(command)
case SMP_Identity_Address_Information_Command():
self.on_smp_identity_address_information_command(command)
case SMP_Signing_Information_Command():
self.on_smp_signing_information_command(command)
case SMP_Pairing_Public_Key_Command():
self.on_smp_pairing_public_key_command(command)
case SMP_Pairing_DHKey_Check_Command():
self.on_smp_pairing_dhkey_check_command(command)
# case SMP_Security_Request_Command():
# self.on_smp_security_request_command(command)
# case SMP_Pairing_Keypress_Notification_Command():
# self.on_smp_pairing_keypress_notification_command(command)
case _:
logger.error(color('SMP command not handled', 'red'))
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(reason=ErrorCode.UNSPECIFIED_REASON)
self.send_command(response)
# Find the handler method
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(command)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
self.send_command(response)
else:
logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
@@ -1425,16 +1436,16 @@ class Session:
accepted = False
if not accepted:
logger.debug('pairing rejected by delegate')
self.send_pairing_failed(ErrorCode.PAIRING_NOT_SUPPORTED)
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
return
# Save the request
self.preq = bytes(command)
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
self.ct2 = self.ct2 and (command.auth_req & AuthReq.CT2 != 0)
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1445,7 +1456,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1464,11 +1475,8 @@ class Session:
(
self.initiator_key_distribution,
self.responder_key_distribution,
) = map(
KeyDistribution,
await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
),
) = await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
)
self.compute_peer_expected_distributions(self.initiator_key_distribution)
@@ -1506,8 +1514,8 @@ class Session:
self.peer_io_capability = command.io_capability
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1518,7 +1526,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1538,7 +1546,7 @@ class Session:
command.responder_key_distribution & ~self.responder_key_distribution != 0
):
# The response isn't a subset of the request
self.send_pairing_failed(ErrorCode.INVALID_PARAMETERS)
self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR)
return
self.initiator_key_distribution = command.initiator_key_distribution
self.responder_key_distribution = command.responder_key_distribution
@@ -1616,7 +1624,7 @@ class Session:
)
assert self.confirm_value
if not self.check_expected_value(
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
@@ -1657,7 +1665,7 @@ class Session:
self.pkb, self.pka, command.random_value, bytes([0])
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1670,7 +1678,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
@@ -1699,7 +1707,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
@@ -1816,7 +1824,7 @@ class Session:
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
ErrorCode.CONFIRM_VALUE_FAILED,
SMP_CONFIRM_VALUE_FAILED_ERROR,
):
return
@@ -1850,7 +1858,7 @@ class Session:
expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value(
expected, command.dhkey_check, ErrorCode.DHKEY_CHECK_FAILED
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
):
return
@@ -1929,7 +1937,6 @@ class Manager(utils.EventEmitter):
self._ecc_key = None
self.pairing_config_factory = pairing_config_factory
self.session_proxy = Session
self.debug_mode = False
def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug(
@@ -1955,7 +1962,7 @@ class Manager(utils.EventEmitter):
)
# Security request is more than just pairing, so let applications handle them
if command.code == CommandCode.SECURITY_REQUEST:
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
@@ -1976,13 +1983,6 @@ class Manager(utils.EventEmitter):
@property
def ecc_key(self) -> crypto.EccKey:
if self.debug_mode:
# Core - Vol 3, Part H:
# When the Security Manager is placed in a Debug mode it shall use the
# following Diffie-Hellman private / public key pair:
debug_key = crypto.EccKey.from_private_key_bytes(SMP_DEBUG_KEY_PRIVATE)
return debug_key
if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
@@ -2002,13 +2002,15 @@ class Manager(utils.EventEmitter):
def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection)
if pairing_config:
auth_req = AuthReq.from_booleans(
bonding=pairing_config.bonding,
sc=pairing_config.sc,
mitm=pairing_config.mitm,
auth_req = smp_auth_req(
pairing_config.bonding,
pairing_config.mitm,
pairing_config.sc,
False,
False,
)
else:
auth_req = AuthReq(0)
auth_req = 0
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session: Session) -> None:
@@ -2024,7 +2026,7 @@ class Manager(utils.EventEmitter):
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: ErrorCode) -> None:
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session: Session) -> None:

File diff suppressed because it is too large Load Diff

View File

@@ -133,10 +133,10 @@ def on_avrcp_start(
utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed() -> None:
async for uid in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", hex(uid))
async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex())
websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": hex(uid)}}
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
)
async def monitor_playback_status() -> None:

View File

@@ -83,7 +83,6 @@ async def main() -> None:
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
server_device.add_service(device_info_service)
await server_device.start_advertising()
# Connect the client to the server
connection = await client_device.connect(server_device.random_address)

View File

@@ -20,17 +20,110 @@ import contextlib
import functools
import json
import sys
import wave
import websockets.asyncio.server
import bumble.logging
from bumble import hci, hfp, rfcomm
from bumble.device import Connection, Device
from bumble.device import Connection, Device, ScoLink
from bumble.hfp import HfProtocol
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
ws: websockets.asyncio.server.ServerConnection | None = None
hf_protocol: HfProtocol | None = None
input_wav: wave.Wave_read | None = None
output_wav: wave.Wave_write | None = None
# -----------------------------------------------------------------------------
def on_audio_packet(packet: hci.HCI_SynchronousDataPacket) -> None:
if (
packet.packet_status
== hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA
):
if output_wav:
# Save the PCM audio to the output
output_wav.writeframes(packet.data)
else:
print('!!! discarding packet with status ', packet.packet_status.name)
if input_wav and hf_protocol:
# Send PCM audio from the input
frame_count = len(packet.data) // 2
while frame_count:
# NOTE: we use a fixed number of frames here, this should likely be adjusted
# based on the transport parameters (like the USB max packet size)
chunk_size = min(frame_count, 16)
if not (pcm_data := input_wav.readframes(chunk_size)):
return
frame_count -= chunk_size
hf_protocol.dlc.multiplexer.l2cap_channel.connection.device.host.send_sco_sdu(
connection_handle=packet.connection_handle,
sdu=pcm_data,
)
# -----------------------------------------------------------------------------
def on_sco_connection(link: ScoLink) -> None:
print('### SCO connection established:', link)
if link.air_mode == hci.CodecID.TRANSPARENT:
print("@@@ The controller does not encode/decode voice")
return
link.sink = on_audio_packet
# -----------------------------------------------------------------------------
def on_sco_request(
link_type: int, connection: Connection, protocol: HfProtocol
) -> None:
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.SCO_CVSD_D1]
elif protocol.active_codec == hfp.AudioCodec.MSBC:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_MSBC_T2]
elif protocol.active_codec == hfp.AudioCodec.CVSD:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S4]
else:
raise RuntimeError("unknown active codec")
if connection.device.host.supports_command(
hci.HCI_ENHANCED_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
):
connection.cancel_on_disconnection(
connection.device.send_async_command(
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address, **esco_parameters.asdict()
)
)
)
elif connection.device.host.supports_command(
hci.HCI_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
):
connection.cancel_on_disconnection(
connection.device.send_async_command(
hci.HCI_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address,
transmit_bandwidth=esco_parameters.transmit_bandwidth,
receive_bandwidth=esco_parameters.receive_bandwidth,
max_latency=esco_parameters.max_latency,
voice_setting=int(
hci.VoiceSetting(
input_sample_size=hci.VoiceSetting.InputSampleSize.SIZE_16_BITS,
input_data_format=hci.VoiceSetting.InputDataFormat.TWOS_COMPLEMENT,
)
),
retransmission_effort=esco_parameters.retransmission_effort,
packet_type=esco_parameters.packet_type,
)
)
)
else:
print('!!! no supported command for SCO connection request')
return
connection.on('sco_connection', on_sco_connection)
# -----------------------------------------------------------------------------
@@ -40,134 +133,172 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration):
hf_protocol = HfProtocol(dlc, configuration)
asyncio.create_task(hf_protocol.run())
def on_sco_request(connection: Connection, link_type: int, protocol: HfProtocol):
if connection == protocol.dlc.multiplexer.l2cap_channel.connection:
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.SCO_CVSD_D1
]
elif protocol.active_codec == hfp.AudioCodec.MSBC:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.ESCO_MSBC_T2
]
elif protocol.active_codec == hfp.AudioCodec.CVSD:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.ESCO_CVSD_S4
]
else:
raise RuntimeError("unknown active codec")
connection.cancel_on_disconnection(
connection.device.send_command(
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address, **esco_parameters.asdict()
)
)
)
handler = functools.partial(on_sco_request, protocol=hf_protocol)
dlc.multiplexer.l2cap_channel.connection.device.on('sco_request', handler)
connection = dlc.multiplexer.l2cap_channel.connection
handler = functools.partial(
on_sco_request,
connection=connection,
protocol=hf_protocol,
)
connection.on('sco_request', handler)
dlc.multiplexer.l2cap_channel.once(
'close',
lambda: dlc.multiplexer.l2cap_channel.connection.device.remove_listener(
'sco_request', handler
),
lambda: connection.remove_listener('sco_request', handler),
)
def on_ag_indicator(indicator):
global ws
if ws:
asyncio.create_task(ws.send(str(indicator)))
hf_protocol.on('ag_indicator', on_ag_indicator)
hf_protocol.on('codec_negotiation', on_codec_negotiation)
# -----------------------------------------------------------------------------
def on_ag_indicator(indicator):
global ws
if ws:
asyncio.create_task(ws.send(str(indicator)))
# -----------------------------------------------------------------------------
def on_codec_negotiation(codec: hfp.AudioCodec):
print(f'### Negotiated codec: {codec.name}')
global output_wav
if output_wav:
output_wav.setnchannels(1)
output_wav.setsampwidth(2)
match codec:
case hfp.AudioCodec.CVSD:
output_wav.setframerate(8000)
case hfp.AudioCodec.MSBC:
output_wav.setframerate(16000)
# -----------------------------------------------------------------------------
async def run(device: Device, codec: str | None) -> None:
if codec is None:
supported_audio_codecs = [hfp.AudioCodec.CVSD, hfp.AudioCodec.MSBC]
else:
if codec == 'cvsd':
supported_audio_codecs = [hfp.AudioCodec.CVSD]
elif codec == 'msbc':
supported_audio_codecs = [hfp.AudioCodec.MSBC]
else:
print('Unknown codec: ', codec)
return
# Hands-Free profile configuration.
# TODO: load configuration from file.
configuration = hfp.HfConfiguration(
supported_hf_features=[
hfp.HfFeature.THREE_WAY_CALLING,
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
hfp.HfFeature.ENHANCED_CALL_STATUS,
hfp.HfFeature.ENHANCED_CALL_CONTROL,
hfp.HfFeature.CODEC_NEGOTIATION,
hfp.HfFeature.HF_INDICATORS,
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
],
supported_hf_indicators=[
hfp.HfIndicator.BATTERY_LEVEL,
],
supported_audio_codecs=supported_audio_codecs,
)
# Create and register a server
rfcomm_server = rfcomm.Server(device)
# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
# Advertise the HFP RFComm channel in the SDP
device.sdp_service_records = {
0x00010001: hfp.make_hf_sdp_records(0x00010001, channel_number, configuration)
}
# Let's go!
await device.power_on()
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Start the UI websocket server to offer a few buttons and input boxes
async def serve(websocket: websockets.asyncio.server.ServerConnection):
global ws
ws = websocket
async for message in websocket:
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'at_command':
if hf_protocol is not None:
response = str(
await hf_protocol.execute_command(
parsed['command'],
response_type=hfp.AtResponseType.MULTIPLE,
)
)
await websocket.send(response)
elif message_type == 'query_call':
if hf_protocol:
response = str(await hf_protocol.query_current_calls())
await websocket.send(response)
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await asyncio.get_running_loop().create_future() # run forever
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_classic_hfp.py <device-config> <transport-spec>')
print('example: run_classic_hfp.py classic2.json usb:04b4:f901')
print(
'Usage: run_hfp_handsfree.py <device-config> <transport-spec> '
'[codec] [input] [output]'
)
print('example: run_hfp_handsfree.py classic2.json usb:0')
return
print('<<< connecting to HCI...')
async with await open_transport(sys.argv[2]) as hci_transport:
print('<<< connected')
device_config = sys.argv[1]
transport_spec = sys.argv[2]
# Hands-Free profile configuration.
# TODO: load configuration from file.
configuration = hfp.HfConfiguration(
supported_hf_features=[
hfp.HfFeature.THREE_WAY_CALLING,
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
hfp.HfFeature.ENHANCED_CALL_STATUS,
hfp.HfFeature.ENHANCED_CALL_CONTROL,
hfp.HfFeature.CODEC_NEGOTIATION,
hfp.HfFeature.HF_INDICATORS,
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
],
supported_hf_indicators=[
hfp.HfIndicator.BATTERY_LEVEL,
],
supported_audio_codecs=[
hfp.AudioCodec.CVSD,
hfp.AudioCodec.MSBC,
],
)
codec: str | None = None
if len(sys.argv) >= 4:
codec = sys.argv[3]
# Create a device
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True
input_file_name: str | None = None
if len(sys.argv) >= 5:
input_file_name = sys.argv[4]
# Create and register a server
rfcomm_server = rfcomm.Server(device)
output_file_name: str | None = None
if len(sys.argv) >= 6:
output_file_name = sys.argv[5]
# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
global input_wav, output_wav
input_cm: contextlib.AbstractContextManager[wave.Wave_read | None] = (
wave.open(input_file_name, "rb")
if input_file_name
else contextlib.nullcontext(None)
)
output_cm: contextlib.AbstractContextManager[wave.Wave_write | None] = (
wave.open(output_file_name, "wb")
if output_file_name
else contextlib.nullcontext(None)
)
with input_cm as input_wav, output_cm as output_wav:
if input_wav and input_wav.getnchannels() != 1:
print("Mono input required")
return
if input_wav and input_wav.getsampwidth() != 2:
print("16-bit input required")
return
# Advertise the HFP RFComm channel in the SDP
device.sdp_service_records = {
0x00010001: hfp.make_hf_sdp_records(
0x00010001, channel_number, configuration
async with await open_transport(transport_spec) as transport:
device = Device.from_config_file_with_hci(
device_config, transport.source, transport.sink
)
}
# Let's go!
await device.power_on()
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Start the UI websocket server to offer a few buttons and input boxes
async def serve(websocket: websockets.asyncio.server.ServerConnection):
global ws
ws = websocket
async for message in websocket:
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'at_command':
if hf_protocol is not None:
response = str(
await hf_protocol.execute_command(
parsed['command'],
response_type=hfp.AtResponseType.MULTIPLE,
)
)
await websocket.send(response)
elif message_type == 'query_call':
if hf_protocol:
response = str(await hf_protocol.query_current_calls())
await websocket.send(response)
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await hci_transport.source.terminated
device.classic_enabled = True
await run(device, codec)
# -----------------------------------------------------------------------------

View File

@@ -13,12 +13,13 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.10"
dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
"cryptography >= 39.0.0; platform_system=='Emscripten'",
"cryptography >= 44.0.3; platform_system=='Emscripten'",
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
# updated. Relax the version requirement since it's better than being completely unable
@@ -36,7 +37,7 @@ dependencies = [
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
"pyserial >= 3.5; platform_system!='Emscripten'",
"pyusb >= 1.2; platform_system!='Emscripten'",
"tomli ~= 2.2.1; platform_system!='Emscripten' and python_version<'3.11'",
"tomli ~= 2.2.1; platform_system!='Emscripten'",
"websockets >= 15.0.1; platform_system!='Emscripten'",
]

4
rust/Cargo.lock generated
View File

@@ -221,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
[[package]]
name = "bytes"
version = "1.11.1"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]]
name = "cc"

View File

@@ -30,7 +30,7 @@ hex = "0.4.3"
itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
bytes = "1.11.1"
bytes = "1.5.0"
pdl-derive = "0.2.0"
pdl-runtime = "0.2.0"
futures = "0.3.28"

View File

@@ -170,7 +170,9 @@ def format_code(ctx, check=False, diff=False):
@task
def check_types(ctx):
checklist = ["apps", "bumble", "examples", "tests", "tasks.py"]
print(">>> Running the type checker...")
try:
print("+++ Checking with mypy...")
ctx.run(f"mypy {' '.join(checklist)}")
except UnexpectedExit as exc:
print("Please check your code against the mypy messages.")

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import asyncio
import struct
from collections.abc import Sequence
from unittest import mock
import pytest
@@ -119,6 +118,8 @@ class TwoDevices(test_utils.TwoDevices):
scope=avrcp.Scope.NOW_PLAYING,
uid=0,
uid_counter=1,
start_item=0,
end_item=0,
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
),
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
@@ -135,7 +136,7 @@ def test_command(command: avrcp.Command):
"event,",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(uid=12356),
avrcp.TrackChangedEvent(identifier=b'12356'),
avrcp.VolumeChangedEvent(volume=9),
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
avrcp.AddressedPlayerChangedEvent(
@@ -580,87 +581,6 @@ async def test_get_supported_company_ids():
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_player_application_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: [
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.SINGLE_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.OFF,
],
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: [
avrcp.ApplicationSetting.ShuffleOnOffStatus.OFF,
avrcp.ApplicationSetting.ShuffleOnOffStatus.ALL_TRACKS_SHUFFLE,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
}
two_devices.protocols[1].delegate = avrcp.Delegate(
supported_player_app_settings=expected_settings
)
actual_settings = await two_devices.protocols[
0
].list_supported_player_app_settings()
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_set_player_app_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.SetPlayerApplicationSettingValueCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
],
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
),
)
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
}
assert delegate.player_app_settings == expected_settings
actual_settings = await two_devices.protocols[0].get_player_app_settings(
[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
)
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_play_item():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
with mock.patch.object(delegate, delegate.play_item.__name__) as play_item_mock:
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.PlayItemCommand(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
),
)
play_item_mock.assert_called_once_with(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_volume():
@@ -715,102 +635,6 @@ async def test_monitor_now_playing_content():
await anext(now_playing_iter)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_track_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.TRACK_CHANGED]
)
delegate.current_track_uid = avrcp.TrackChangedEvent.NO_TRACK
track_iter = two_devices.protocols[0].monitor_track_changed()
# Interim
assert (await anext(track_iter)) == avrcp.TrackChangedEvent.NO_TRACK
# Changed
two_devices.protocols[1].notify_track_changed(1)
assert (await anext(track_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_uid_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.UIDS_CHANGED]
)
delegate.uid_counter = 0
uid_iter = two_devices.protocols[0].monitor_uids()
# Interim
assert (await anext(uid_iter)) == 0
# Changed
two_devices.protocols[1].notify_uids_changed(1)
assert (await anext(uid_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_addressed_player():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.ADDRESSED_PLAYER_CHANGED]
)
delegate.uid_counter = 0
delegate.addressed_player_id = 0
addressed_player_iter = two_devices.protocols[0].monitor_addressed_player()
# Interim
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=0, uid_counter=0)
# Changed
two_devices.protocols[1].notify_addressed_player_changed(
avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
)
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_player_app_settings():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
supported_events=[avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED]
)
delegate.player_app_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
}
settings_iter = two_devices.protocols[0].monitor_player_application_settings()
# Interim
interim = await anext(settings_iter)
assert interim[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert (
interim[0].value_id
== avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
)
# Changed
two_devices.protocols[1].notify_player_application_settings_changed(
[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
)
]
)
changed = await anext(settings_iter)
assert changed[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert changed[0].value_id == avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frame_parser()

View File

@@ -73,14 +73,6 @@ def test_uuid_to_hex_str() -> None:
)
# -----------------------------------------------------------------------------
def test_uuid_hash() -> None:
uuid = UUID("1234")
uuid_128_bytes = UUID.from_bytes(uuid.to_bytes(force_128=True))
assert uuid in {uuid_128_bytes}
assert uuid_128_bytes in {uuid}
# -----------------------------------------------------------------------------
def test_appearance() -> None:
a = Appearance(Appearance.Category.COMPUTER, Appearance.ComputerSubcategory.LAPTOP)

View File

@@ -309,27 +309,6 @@ async def test_legacy_advertising_disconnection(auto_restart):
assert not devices[0].is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_le_multiple_connects():
devices = TwoDevices()
for controller in devices.controllers:
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
for dev in devices:
await dev.power_on()
await devices[0].start_advertising(auto_restart=True, advertising_interval_min=1.0)
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
await async_barrier()
await async_barrier()
# a second connection attempt is working
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_advertising_and_scanning():
@@ -466,9 +445,7 @@ async def test_get_remote_le_features():
devices = TwoDevices()
await devices.setup_connection()
assert (
await devices.connections[0].get_remote_le_features()
) == devices.controllers[1].le_features
assert (await devices.connections[0].get_remote_le_features()) is not None
# -----------------------------------------------------------------------------
@@ -826,22 +803,6 @@ async def test_remote_name_request():
assert actual_name == expected_name
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_classic_features():
devices = TwoDevices()
devices[0].classic_enabled = True
devices[1].classic_enabled = True
await devices[0].power_on()
await devices[1].power_on()
connection = await devices[0].connect_classic(devices[1].public_address)
assert (
await asyncio.wait_for(connection.get_remote_classic_features(), _TIMEOUT)
== devices.controllers[1].lmp_features
)
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()

View File

@@ -22,7 +22,6 @@ import unittest.mock
import pytest
from bumble import controller, hci
from bumble.controller import Controller
from bumble.hci import (
HCI_AclDataPacket,
@@ -50,27 +49,34 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'supported_commands, max_lmp_features_page_number',
'supported_commands, lmp_features',
[
(controller.Controller.supported_commands, 0),
(
# Default commands
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000',
# Only LE LMP feature
'0000000060000000',
),
(
# All commands
set(hci.HCI_Command.command_names.keys()),
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
# 3 pages of LMP features
2,
'000102030405060708090A0B0C0D0E0F011112131415161718191A1B1C1D1E1F',
),
],
)
async def test_reset(supported_commands: set[int], max_lmp_features_page_number: int):
async def test_reset(supported_commands: str, lmp_features: str):
controller = Controller('C')
controller.supported_commands = supported_commands
controller.lmp_features_max_page_number = max_lmp_features_page_number
controller.supported_commands = bytes.fromhex(supported_commands)
controller.lmp_features = bytes.fromhex(lmp_features)
host = Host(controller, AsyncPipeSink(controller))
await host.reset()
assert host.local_lmp_features == (
controller.lmp_features & ~(1 << (64 * max_lmp_features_page_number + 1))
assert host.local_lmp_features == int.from_bytes(
bytes.fromhex(lmp_features), 'little'
)
@@ -171,15 +177,14 @@ class Source:
class Sink:
response: HCI_Event | None
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event | None) -> None:
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
if self.response is not None:
self.source.sink.on_packet(bytes(self.response))
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
@@ -229,23 +234,6 @@ async def test_send_sync_command() -> None:
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
source = Source()
sink = Sink(source, None)
host = Host(source, sink)
host.ready = True
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
# The sending semaphore should have been released, so this should not block
# indefinitely
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()

View File

@@ -21,7 +21,6 @@ import logging
import os
import pathlib
import tempfile
from unittest import mock
import pytest
@@ -180,55 +179,11 @@ async def test_default_namespace(temporary_file):
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_filename(tmp_path):
import platformdirs
with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path):
# Case 1: no namespace, no filename
keystore = JsonKeyStore(None, None)
expected_directory = tmp_path / 'Pairing'
expected_filename = expected_directory / 'keys.json'
assert keystore.directory_name == expected_directory
assert keystore.filename == expected_filename
# Save some data
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
assert expected_filename.exists()
# Load back
keystore2 = JsonKeyStore(None, None)
foo = await keystore2.get('foo')
assert foo is not None
assert foo.ltk.value == ltk
# Case 2: namespace, no filename
keystore3 = JsonKeyStore('my:namespace', None)
# safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-')
expected_filename3 = expected_directory / 'my-namespace.json'
assert keystore3.filename == expected_filename3
# Save some data
await keystore3.update('bar', keys)
assert expected_filename3.exists()
# Load back
keystore4 = JsonKeyStore('my:namespace', None)
bar = await keystore4.get('bar')
assert bar is not None
assert bar.ltk.value == ltk
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
await test_no_filename()
# -----------------------------------------------------------------------------

View File

@@ -29,7 +29,8 @@ from bumble.gatt import Characteristic, Service
from bumble.hci import Role
from bumble.pairing import PairingConfig, PairingDelegate
from bumble.smp import (
ErrorCode,
SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_PAIRING_NOT_SUPPORTED_ERROR,
OobContext,
OobLegacyContext,
)
@@ -377,7 +378,7 @@ async def test_self_smp_reject():
await _test_self_smp_with_configs(None, rejecting_pairing_config)
paired = True
except ProtocolError as error:
assert error.error_code == ErrorCode.PAIRING_NOT_SUPPORTED
assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
assert not paired
@@ -402,7 +403,7 @@ async def test_self_smp_wrong_pin():
)
paired = True
except ProtocolError as error:
assert error.error_code == ErrorCode.CONFIRM_VALUE_FAILED
assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert not paired
@@ -533,11 +534,11 @@ async def test_self_smp_oob_sc():
with pytest.raises(ProtocolError) as error:
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
with pytest.raises(ProtocolError):
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
# -----------------------------------------------------------------------------

View File

@@ -24,7 +24,7 @@ import pytest
from bumble import crypto, pairing, smp
from bumble.core import AdvertisingData
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.device import Device, DeviceConfiguration
from bumble.device import Device
from bumble.hci import Address
from bumble.pairing import LeRole, OobData, OobSharedData
@@ -312,17 +312,3 @@ async def test_send_identity_address_command(
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
@pytest.mark.asyncio
async def test_smp_debug_mode():
config = DeviceConfiguration(smp_debug_mode=True)
device = Device(config=config)
assert device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
device.smp_manager.debug_mode = False
assert not device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert not device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y

View File

@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" />
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="heart_rate_monitor.js"></script>
<style>

View File

@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="scanner.css">
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="scanner.js"></script>
</style>

View File

@@ -4,7 +4,7 @@
<title>Bumble Speaker</title>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="speaker.css">
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script type="module" src="speaker.js"></script>
<script type="module" src="../ui.js"></script>
</head>