Compare commits

...

59 Commits

Author SHA1 Message Date
dependabot[bot]
834c8acd85 Bump rand in /rust in the cargo group across 1 directory
Bumps the cargo group with 1 update in the /rust directory: [rand](https://github.com/rust-random/rand).


Updates `rand` from 0.8.5 to 0.9.3
- [Release notes](https://github.com/rust-random/rand/releases)
- [Changelog](https://github.com/rust-random/rand/blob/0.9.3/CHANGELOG.md)
- [Commits](https://github.com/rust-random/rand/compare/0.8.5...0.9.3)

---
updated-dependencies:
- dependency-name: rand
  dependency-version: 0.9.3
  dependency-type: direct:production
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 02:05:17 +00:00
Josh Wu
f2824ee6b8 Merge pull request #907 from zxzxwu/example-gatt-client-and-server
Advertise in run_gatt_client_and_server
2026-04-13 16:31:19 +08:00
Josh Wu
7188ef08de Advertise in run_gatt_client_and_server 2026-04-13 15:31:32 +08:00
Josh Wu
3ded9014d3 Merge pull request #905 from markusjellitsch/feature/debug-keys
Feature  - Add SMP Debug Mode  (Core Vol.3, Part H)
2026-04-09 15:36:42 +08:00
Josh Wu
b6125bdfb1 Merge pull request #904 from zxzxwu/keys
Keys: Remove appdirs and improve typing
2026-04-09 15:30:39 +08:00
Markus Jellitsch
dc17f4f1ca remove asserts 2026-04-08 20:58:47 +02:00
Markus Jellitsch
3f65380c20 remove comment 2026-04-03 23:19:43 +02:00
Markus Jellitsch
25a0056ecc remove uncommented line 2026-04-03 23:08:16 +02:00
Markus Jellitsch
85f6b10983 run formatter 2026-04-03 23:06:24 +02:00
Markus Jellitsch
e85f041e9d add test for smp debug mode 2026-04-03 23:04:48 +02:00
Markus Jellitsch
ee09e6f10d add smp_debug_mode config flag to enable debug keys during device init 2026-04-03 23:03:51 +02:00
Markus Jellitsch
c3daf4a7e1 implement debug mode for smp manager using defined private / public key pair 2026-04-03 23:02:15 +02:00
Josh Wu
3af623be7e Keys: Remove appdirs and improve typing 2026-03-31 16:25:15 +08:00
Gilles Boccon-Gibod
4e76d3057b Merge pull request #903 from sameer/micropip-install-crypto-issue
Fix Hive demo install failure
2026-03-28 15:35:32 -04:00
Sameer Puri
eda7360222 Upgrade pyodide in web fixes import error
Prior to this, these web pages fail to load with
`ImportError: cannot import name 'TypeIs' from 'typing_extensions'
(/lib/python3.11/site-packages/typing_extensions.py)`
2026-03-26 18:39:07 +00:00
Sameer Puri
a4c15c00de Downgrade cryptography, fixes micropip failure
Prior to this, these web pages fail to load with

`ValueError: Can't find a pure Python 3 wheel for 'cryptography>=44.0.3;
platform_system == "Emscripten"'.`
2026-03-26 18:38:12 +00:00
Josh Wu
cba4df4aef Merge pull request #900 from zxzxwu/lmp-feat
Add read classic remote features support
2026-03-24 14:03:29 +08:00
Josh Wu
ceb8b448e9 Merge pull request #901 from zxzxwu/rust
Add --locked to allow installing cargo-all-features
2026-03-21 03:45:47 +08:00
Josh Wu
311b716d5c Add --locked to allow installing cargo-all-features 2026-03-20 18:44:49 +08:00
Josh Wu
0ba9e5c317 Add read classic remote features support 2026-03-20 18:32:52 +08:00
Josh Wu
3517225b62 Merge pull request #898 from zxzxwu/phy
Make ConnectionPHY dataclass
2026-03-13 12:04:45 +08:00
Josh Wu
ad4bb1578b Make ConnectionPHY dataclass 2026-03-11 21:41:48 +08:00
Josh Wu
4af65b381b Merge pull request #820 from zxzxwu/sdp
SDP: Migrate to dataclasses
2026-03-04 13:45:39 +08:00
Josh Wu
a5cd3365ae Merge pull request #895 from zxzxwu/uuid
Hash and cache 128 bytes of UUID
2026-03-04 00:29:43 +08:00
Josh Wu
2915cb8bb6 Add test for UUID hash 2026-03-04 00:22:50 +08:00
Josh Wu
28e485b7b3 Hash and cache 128 bytes of UUID 2026-03-03 17:54:27 +08:00
Josh Wu
1198f2c3f5 SDP: Make PDU dataclasses 2026-03-03 02:07:08 +08:00
Josh Wu
80aaf6a2b9 SDP: Make DataElement and ServiceAttribute dataclasses 2026-03-03 01:28:40 +08:00
Josh Wu
eb64debb62 Merge pull request #893 from zxzxwu/le-emu
Emulation: Support LE Read features
2026-03-01 17:01:11 +08:00
Josh Wu
c158f25b1e Emulation: Support LE Read features 2026-03-01 02:24:55 +08:00
Josh Wu
1330e83517 Merge pull request #892 from zxzxwu/hfp
HFP: Fix response handling
2026-02-26 13:18:03 +08:00
Josh Wu
d9c9bea6cb HFP: Fix response handling 2026-02-25 00:39:45 +08:00
Gilles Boccon-Gibod
3b937631b3 Merge pull request #891 from a-detiste/main 2026-02-18 21:13:09 -08:00
Alexandre Detiste
f8aa309111 fix pyproject.toml format 2026-02-18 16:39:09 +01:00
Alexandre Detiste
673281ed71 use tomllib from standard library on Python3.11+ 2026-02-18 11:11:49 +01:00
Josh Wu
3ac7af4683 Merge pull request #886 from zxzxwu/controller-status
Controller: Use new return parameter types and add _send_hci_command_status
2026-02-11 13:27:32 +08:00
Josh Wu
5ebfaae74e Controller: Use new return parameter types and add _send_hci_command_status() 2026-02-11 13:21:47 +08:00
Josh Wu
e6175f85fe Merge pull request #887 from zxzxwu/gap
Remove bumble.gap
2026-02-11 13:15:39 +08:00
Josh Wu
f9ba527508 Merge pull request #821 from zxzxwu/smp
Migrate most enums
2026-02-11 13:15:22 +08:00
Josh Wu
a407c4cabf Merge pull request #883 from zxzxwu/avrcp
AVRCP: More delegation and bugfix
2026-02-11 13:13:16 +08:00
Josh Wu
6c2d6dddb5 Merge pull request #885 from zxzxwu/match-case
Replace long if-else with match-case
2026-02-11 13:12:38 +08:00
Josh Wu
797cd216d4 SMP: Migrate all enums 2026-02-10 20:08:01 +08:00
Josh Wu
e2e8c90e47 Remove bumble.gap 2026-02-10 17:40:22 +08:00
Josh Wu
3d5648cdc3 Replace long if-else with match-case 2026-02-10 17:35:39 +08:00
Josh Wu
d810d93aaf Merge pull request #884 from timrid/fix-multiple-le-connections
Connecting multiple times to a LE device is working correctly again
2026-02-06 11:25:44 +08:00
timrid
81d9adb983 delete only the required connection 2026-02-05 20:50:58 +01:00
Josh Wu
377fa896f7 Merge pull request #881 from google/dependabot/cargo/rust/cargo-f6ecf5c85a
Bump bytes from 1.5.0 to 1.11.1 in /rust in the cargo group across 1 directory
2026-02-05 23:55:37 +08:00
timrid
79e5974946 Multiple le connections are now working correctly 2026-02-05 13:15:57 +01:00
Josh Wu
657451474e AVRCP: Address type errors 2026-02-05 16:01:21 +08:00
Josh Wu
9f730dce6f AVRCP: Delegate Track Changed 2026-02-05 15:50:06 +08:00
Josh Wu
1a6be95a7e AVRCP: Delegate UID and Addressed Player 2026-02-05 15:44:11 +08:00
Josh Wu
aea5320d71 AVRCP: Add Play Item delegation 2026-02-05 15:34:03 +08:00
Josh Wu
91cb1b1df3 AVRCP: Add available player changed event 2026-02-05 15:25:17 +08:00
Josh Wu
81bdc86e52 AVRCP: Delegate Player App Settings 2026-02-05 15:22:11 +08:00
Josh Wu
f23cad34e3 AVRCP: Use match-case 2026-02-04 22:23:53 +08:00
Josh Wu
30fde2c00b AVRCP: Fix wrong packet field specs 2026-02-04 18:05:25 +08:00
Josh Wu
256a1a7405 Merge pull request #882 from zxzxwu/hci
Fix wrong LE event codes
2026-02-04 17:40:54 +08:00
Josh Wu
116d9b26bb Fix wrong LE event codes 2026-02-04 15:03:08 +08:00
dependabot[bot]
aabe2ca063 Bump bytes in /rust in the cargo group across 1 directory
Bumps the cargo group with 1 update in the /rust directory: [bytes](https://github.com/tokio-rs/bytes).


Updates `bytes` from 1.5.0 to 1.11.1
- [Release notes](https://github.com/tokio-rs/bytes/releases)
- [Changelog](https://github.com/tokio-rs/bytes/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tokio-rs/bytes/compare/v1.5.0...v1.11.1)

---
updated-dependencies:
- dependency-name: bytes
  dependency-version: 1.11.1
  dependency-type: direct:production
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-03 20:46:35 +00:00
39 changed files with 2560 additions and 1837 deletions

View File

@@ -69,7 +69,7 @@ jobs:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
run: cargo install cargo-all-features --version 1.11.0 --locked # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build

View File

@@ -24,13 +24,18 @@ import dataclasses
import functools
import logging
import secrets
import sys
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
from typing import (
Any,
)
import click
import tomli
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
try:
import lc3 # type: ignore # pylint: disable=E0401
@@ -114,7 +119,7 @@ def parse_broadcast_list(filename: str) -> Sequence[Broadcast]:
broadcasts: list[Broadcast] = []
with open(filename, "rb") as config_file:
config = tomli.load(config_file)
config = tomllib.load(config_file)
for broadcast in config.get("broadcasts", []):
sources = []
for source in broadcast.get("sources", []):

View File

@@ -20,11 +20,12 @@ from __future__ import annotations
import asyncio
import logging
import os
from typing import ClassVar
import click
from prompt_toolkit.shortcuts import PromptSession
from bumble import data_types
from bumble import data_types, smp
from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
@@ -40,7 +41,7 @@ from bumble.core import (
PhysicalTransport,
ProtocolError,
)
from bumble.device import Device, Peer
from bumble.device import Connection, Device, Peer
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -53,7 +54,6 @@ from bumble.hci import OwnAddressType
from bumble.keys import JsonKeyStore
from bumble.pairing import OobData, PairingConfig, PairingDelegate
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -65,7 +65,7 @@ POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance: Waiter | None = None
instance: ClassVar[Waiter | None] = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
@@ -319,12 +319,13 @@ async def on_classic_pairing(connection):
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color(f'*** Pairing failed: {reason.name}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
Waiter.instance.terminate()
if Waiter.instance:
Waiter.instance.terminate()
# -----------------------------------------------------------------------------

View File

@@ -88,13 +88,6 @@ SBC_DUAL_CHANNEL_MODE = 0x01
SBC_STEREO_CHANNEL_MODE = 0x02
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
SBC_CHANNEL_MODE_NAMES = {
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
}
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
SBC_SUBBANDS = [4, 8]
@@ -102,11 +95,6 @@ SBC_SUBBANDS = [4, 8]
SBC_SNR_ALLOCATION_METHOD = 0x00
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
SBC_ALLOCATION_METHOD_NAMES = {
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
@@ -129,13 +117,6 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
@@ -267,26 +248,27 @@ class MediaCodecInformation:
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
match media_codec_type:
case CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
case CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
case CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod

View File

@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string."""
tokens = []
tokens: list[bytearray] = []
in_quotes = False
token = bytearray()
for b in buffer:
@@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
match char:
case b' ':
pass
case b',' | b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
case b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
case b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
case _:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
@@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
current: bytes | list = b''
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = b''
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
match token:
case b',':
accumulator[-1].append(current)
current = b''
case b'(':
accumulator.append([])
case b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
case _:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:

View File

@@ -954,12 +954,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
match attribute_type:
case str():
self.type = UUID(attribute_type)
case bytes():
self.type = UUID.from_bytes(attribute_type)
case _:
self.type = attribute_type
self.value = value
@@ -994,30 +995,31 @@ class Attribute(utils.EventEmitter, Generic[_T]):
)
value: _T | None
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
match self.value:
case AttributeValue():
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
value = self.value
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
@@ -1049,26 +1051,27 @@ class Attribute(utils.EventEmitter, Generic[_T]):
decoded_value = self.decode_value(value)
if isinstance(self.value, AttributeValue):
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value
match self.value:
case AttributeValue():
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
self.value = decoded_value
self.emit(self.EVENT_WRITE, connection, decoded_value)

View File

@@ -22,7 +22,14 @@ import enum
import functools
import logging
import struct
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterable,
Mapping,
Sequence,
)
from dataclasses import dataclass, field
from typing import ClassVar, SupportsBytes, TypeVar
@@ -1049,11 +1056,9 @@ class GetItemAttributesCommand(Command):
scope: Scope = field(metadata=Scope.type_metadata(1))
uid: int = field(metadata=_UINT64_BE_METADATA)
uid_counter: int = field(metadata=hci.metadata('>2'))
start_item: int = field(metadata=hci.metadata('>4'))
end_item: int = field(metadata=hci.metadata('>4'))
# When attributes is empty, all attributes will be requested.
attributes: Sequence[MediaAttributeId] = field(
metadata=MediaAttributeId.type_metadata(1, list_begin=True, list_end=True)
metadata=MediaAttributeId.type_metadata(4, list_begin=True, list_end=True)
)
@@ -1512,7 +1517,9 @@ class PlaybackPositionChangedEvent(Event):
@dataclass
class TrackChangedEvent(Event):
event_id = EventId.TRACK_CHANGED
identifier: bytes = field(metadata=hci.metadata('*'))
NO_TRACK = 0xFFFFFFFFFFFFFFFF
uid: int = field(metadata=_UINT64_BE_METADATA)
# -----------------------------------------------------------------------------
@@ -1536,16 +1543,19 @@ class PlayerApplicationSettingChangedEvent(Event):
def __post_init__(self) -> None:
super().__post_init__()
if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
else:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
match self.attribute_id:
case ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(
self.value_id
)
case ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
case ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
case ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
case _:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
player_application_settings: Sequence[Setting] = field(
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
@@ -1619,6 +1629,8 @@ class Delegate:
supported_events: list[EventId]
supported_company_ids: list[int]
supported_player_app_settings: dict[ApplicationSetting.AttributeId, list[int]]
player_app_settings: dict[ApplicationSetting.AttributeId, int]
volume: int
playback_status: PlayStatus
@@ -1626,11 +1638,23 @@ class Delegate:
self,
supported_events: Iterable[EventId] = (),
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
supported_player_app_settings: (
Mapping[ApplicationSetting.AttributeId, Sequence[int]] | None
) = None,
) -> None:
self.supported_company_ids = list(supported_company_ids)
self.supported_events = list(supported_events)
self.volume = 0
self.playback_status = PlayStatus.STOPPED
self.supported_player_app_settings = (
{key: list(value) for key, value in supported_player_app_settings.items()}
if supported_player_app_settings
else {}
)
self.player_app_settings = {}
self.uid_counter = 0
self.addressed_player_id = 0
self.current_track_uid = TrackChangedEvent.NO_TRACK
async def get_supported_events(self) -> list[EventId]:
return self.supported_events
@@ -1663,6 +1687,38 @@ class Delegate:
async def get_playback_status(self) -> PlayStatus:
return self.playback_status
async def get_supported_player_app_settings(
self,
) -> dict[ApplicationSetting.AttributeId, list[int]]:
return self.supported_player_app_settings
async def get_current_player_app_settings(
self,
) -> dict[ApplicationSetting.AttributeId, int]:
return self.player_app_settings
async def set_player_app_settings(
self, attribute: ApplicationSetting.AttributeId, value: int
) -> None:
self.player_app_settings[attribute] = value
async def play_item(self, scope: Scope, uid: int, uid_counter: int) -> None:
logger.debug(
"@@@ play_item: scope=%s, uid=%s, uid_counter=%s",
scope,
uid,
uid_counter,
)
async def get_uid_counter(self) -> int:
return self.uid_counter
async def get_addressed_player_id(self) -> int:
return self.addressed_player_id
async def get_current_track_uid(self) -> int:
return self.current_track_uid
# TODO add other delegate methods
@@ -1910,6 +1966,51 @@ class Protocol(utils.EventEmitter):
response = self._check_response(response_context, GetElementAttributesResponse)
return list(response.attributes)
async def list_supported_player_app_settings(
self, attribute_ids: Sequence[ApplicationSetting.AttributeId] = ()
) -> dict[ApplicationSetting.AttributeId, list[int]]:
"""Get element attributes from the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
ListPlayerApplicationSettingAttributesCommand(),
)
if not attribute_ids:
list_attribute_response = self._check_response(
response_context, ListPlayerApplicationSettingAttributesResponse
)
attribute_ids = list_attribute_response.attribute
supported_settings: dict[ApplicationSetting.AttributeId, list[int]] = {}
for attribute_id in attribute_ids:
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
ListPlayerApplicationSettingValuesCommand(attribute_id),
)
list_value_response = self._check_response(
response_context, ListPlayerApplicationSettingValuesResponse
)
supported_settings[attribute_id] = list(list_value_response.value)
return supported_settings
async def get_player_app_settings(
self, attribute_ids: Sequence[ApplicationSetting.AttributeId]
) -> dict[ApplicationSetting.AttributeId, int]:
"""Get element attributes from the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
GetCurrentPlayerApplicationSettingValueCommand(attribute_ids),
)
response: GetCurrentPlayerApplicationSettingValueResponse = (
self._check_response(
response_context, GetCurrentPlayerApplicationSettingValueResponse
)
)
return {
attribute_id: value
for attribute_id, value in zip(response.attribute, response.value)
}
async def monitor_events(
self, event_id: EventId, playback_interval: int = 0
) -> AsyncIterator[Event]:
@@ -1961,13 +2062,13 @@ class Protocol(utils.EventEmitter):
async def monitor_track_changed(
self,
) -> AsyncIterator[bytes]:
) -> AsyncIterator[int]:
"""Monitor Track changes from the connected peer."""
async for event in self.monitor_events(EventId.TRACK_CHANGED, 0):
if not isinstance(event, TrackChangedEvent):
logger.warning("unexpected event class")
continue
yield event.identifier
yield event.uid
async def monitor_playback_position(
self, playback_interval: int
@@ -2060,11 +2161,9 @@ class Protocol(utils.EventEmitter):
"""Notify the connected peer of a Playback Status change."""
self.notify_event(PlaybackStatusChangedEvent(status))
def notify_track_changed(self, identifier: bytes) -> None:
def notify_track_changed(self, uid: int) -> None:
"""Notify the connected peer of a Track change."""
if len(identifier) != 8:
raise core.InvalidArgumentError("identifier must be 8 bytes")
self.notify_event(TrackChangedEvent(identifier))
self.notify_event(TrackChangedEvent(uid))
def notify_playback_position_changed(self, position: int) -> None:
"""Notify the connected peer of a Position change."""
@@ -2280,21 +2379,40 @@ class Protocol(utils.EventEmitter):
):
# TODO: catch exceptions from delegates
command = Command.from_bytes(pdu_id, pdu)
if isinstance(command, GetCapabilitiesCommand):
self._on_get_capabilities_command(transaction_label, command)
elif isinstance(command, SetAbsoluteVolumeCommand):
self._on_set_absolute_volume_command(transaction_label, command)
elif isinstance(command, RegisterNotificationCommand):
self._on_register_notification_command(transaction_label, command)
elif isinstance(command, GetPlayStatusCommand):
self._on_get_play_status_command(transaction_label, command)
else:
# Not supported.
# TODO: check that this is the right way to respond in this case.
logger.debug("unsupported PDU ID")
self.send_rejected_avrcp_response(
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
)
match command:
case GetCapabilitiesCommand():
self._on_get_capabilities_command(transaction_label, command)
case SetAbsoluteVolumeCommand():
self._on_set_absolute_volume_command(transaction_label, command)
case RegisterNotificationCommand():
self._on_register_notification_command(transaction_label, command)
case GetPlayStatusCommand():
self._on_get_play_status_command(transaction_label, command)
case ListPlayerApplicationSettingAttributesCommand():
self._on_list_player_application_setting_attributes_command(
transaction_label, command
)
case ListPlayerApplicationSettingValuesCommand():
self._on_list_player_application_setting_values_command(
transaction_label, command
)
case SetPlayerApplicationSettingValueCommand():
self._on_set_player_application_setting_value_command(
transaction_label, command
)
case GetCurrentPlayerApplicationSettingValueCommand():
self._on_get_current_player_application_setting_value_command(
transaction_label, command
)
case PlayItemCommand():
self._on_play_item_command(transaction_label, command)
case _:
# Not supported.
# TODO: check that this is the right way to respond in this case.
logger.debug("unsupported PDU ID")
self.send_rejected_avrcp_response(
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
)
else:
logger.debug("unsupported command type")
self.send_rejected_avrcp_response(
@@ -2322,26 +2440,29 @@ class Protocol(utils.EventEmitter):
# is Ok, but if/when more responses are supported, a lookup mechanism would be
# more appropriate.
response: Response | None = None
if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(pdu_id=pdu_id, status_code=StatusCode(pdu[0]))
elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
elif response_code in (
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
avc.ResponseFrame.ResponseCode.INTERIM,
avc.ResponseFrame.ResponseCode.CHANGED,
avc.ResponseFrame.ResponseCode.ACCEPTED,
):
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
else:
logger.debug("unexpected response code")
pending_command.response.set_exception(
core.ProtocolError(
error_code=None,
error_namespace="avrcp",
details="unexpected response code",
match response_code:
case avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(
pdu_id=pdu_id, status_code=StatusCode(pdu[0])
)
case avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
case (
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE
| avc.ResponseFrame.ResponseCode.INTERIM
| avc.ResponseFrame.ResponseCode.CHANGED
| avc.ResponseFrame.ResponseCode.ACCEPTED
):
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
case _:
logger.debug("unexpected response code")
pending_command.response.set_exception(
core.ProtocolError(
error_code=None,
error_namespace="avrcp",
details="unexpected response code",
)
)
)
if response is None:
self.recycle_pending_command(pending_command)
@@ -2512,22 +2633,18 @@ class Protocol(utils.EventEmitter):
async def get_supported_events() -> None:
capabilities: Sequence[bytes | SupportsBytes]
if (
command.capability_id
== GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
):
capabilities = await self.delegate.get_supported_events()
elif (
command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID
):
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
else:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
match command.capability_id:
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
capabilities = await self.delegate.get_supported_events()
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED.COMPANY_ID:
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
case _:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
@@ -2572,6 +2689,121 @@ class Protocol(utils.EventEmitter):
self._delegate_command(transaction_label, command, get_playback_status())
def _on_list_player_application_setting_attributes_command(
self,
transaction_label: int,
command: ListPlayerApplicationSettingAttributesCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
supported_settings = await self.delegate.get_supported_player_app_settings()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
ListPlayerApplicationSettingAttributesResponse(
list(supported_settings.keys())
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_list_player_application_setting_values_command(
self,
transaction_label: int,
command: ListPlayerApplicationSettingValuesCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
supported_settings = await self.delegate.get_supported_player_app_settings()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
ListPlayerApplicationSettingValuesResponse(
supported_settings.get(command.attribute, [])
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_get_current_player_application_setting_value_command(
self,
transaction_label: int,
command: GetCurrentPlayerApplicationSettingValueCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_supported_player_app_settings() -> None:
current_settings = await self.delegate.get_current_player_app_settings()
if not all(
attribute in current_settings for attribute in command.attribute
):
self.send_not_implemented_avrcp_response(
transaction_label,
PduId.GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE,
)
return
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetCurrentPlayerApplicationSettingValueResponse(
attribute=command.attribute,
value=[
current_settings[attribute] for attribute in command.attribute
],
),
)
self._delegate_command(
transaction_label, command, get_supported_player_app_settings()
)
def _on_set_player_application_setting_value_command(
self,
transaction_label: int,
command: SetPlayerApplicationSettingValueCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def set_player_app_settings() -> None:
for attribute, value in zip(command.attribute, command.value):
await self.delegate.set_player_app_settings(attribute, value)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
SetPlayerApplicationSettingValueResponse(),
)
self._delegate_command(transaction_label, command, set_player_app_settings())
def _on_play_item_command(
self,
transaction_label: int,
command: PlayItemCommand,
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def play_item() -> None:
await self.delegate.play_item(
scope=command.scope, uid=command.uid, uid_counter=command.uid_counter
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
PlayItemResponse(status=StatusCode.OPERATION_COMPLETED),
)
self._delegate_command(transaction_label, command, play_item())
def _on_register_notification_command(
self, transaction_label: int, command: RegisterNotificationCommand
) -> None:
@@ -2587,26 +2819,51 @@ class Protocol(utils.EventEmitter):
)
return
response: Response
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
elif command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=playback_status)
)
elif command.event_id == EventId.NOW_PLAYING_CONTENT_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(NowPlayingContentChangedEvent())
else:
logger.warning("Event supported but not handled %s", command.event_id)
return
event: Event
match command.event_id:
case EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
event = VolumeChangedEvent(volume)
case EventId.PLAYBACK_STATUS_CHANGED:
playback_status = await self.delegate.get_playback_status()
event = PlaybackStatusChangedEvent(play_status=playback_status)
case EventId.NOW_PLAYING_CONTENT_CHANGED:
event = NowPlayingContentChangedEvent()
case EventId.PLAYER_APPLICATION_SETTING_CHANGED:
settings = await self.delegate.get_current_player_app_settings()
event = PlayerApplicationSettingChangedEvent(
[
PlayerApplicationSettingChangedEvent.Setting(
attribute, value # type: ignore
)
for attribute, value in settings.items()
]
)
case EventId.AVAILABLE_PLAYERS_CHANGED:
event = AvailablePlayersChangedEvent()
case EventId.ADDRESSED_PLAYER_CHANGED:
event = AddressedPlayerChangedEvent(
AddressedPlayerChangedEvent.Player(
player_id=await self.delegate.get_addressed_player_id(),
uid_counter=await self.delegate.get_uid_counter(),
)
)
case EventId.UIDS_CHANGED:
event = UidsChangedEvent(await self.delegate.get_uid_counter())
case EventId.TRACK_CHANGED:
event = TrackChangedEvent(
await self.delegate.get_current_track_uid()
)
case _:
logger.warning(
"Event supported but not handled %s", command.event_id
)
return
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
RegisterNotificationResponse(event),
)
self._register_notification_listener(transaction_label, command)

File diff suppressed because it is too large Load Diff

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import dataclasses
import enum
import functools
import struct
from collections.abc import Iterable
from typing import (
@@ -273,6 +274,18 @@ class UUID:
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
@functools.cached_property
def uuid_128_bytes(self) -> bytes:
match len(self.uuid_bytes):
case 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
case 4:
return self.BASE_UUID + self.uuid_bytes
case 16:
return self.uuid_bytes
case _:
assert False, "unreachable"
def to_bytes(self, force_128: bool = False) -> bytes:
'''
Serialize UUID in little-endian byte-order
@@ -280,14 +293,7 @@ class UUID:
if not force_128:
return self.uuid_bytes
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
else:
assert False, "unreachable"
return self.uuid_128_bytes
def to_pdu_bytes(self) -> bytes:
'''
@@ -317,7 +323,7 @@ class UUID:
def __eq__(self, other: object) -> bool:
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
return self.uuid_128_bytes == other.uuid_128_bytes
if isinstance(other, str):
return UUID(other) == self
@@ -325,7 +331,7 @@ class UUID:
return False
def __hash__(self) -> int:
return hash(self.uuid_bytes)
return hash(self.uuid_128_bytes)
def __str__(self) -> str:
result = self.to_hex_str(separator='-')
@@ -1769,66 +1775,71 @@ class AdvertisingData:
@classmethod
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
if ad_type == AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
match ad_type:
case AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
case AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
case AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
case AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
)
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
case AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(
struct.unpack_from('<H', ad_data, 0)[0]
)
ad_data_str = str(appearance)
case AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
case _:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
ad_data_str = str(appearance)
elif ad_type == AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
else:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}'
@@ -2105,13 +2116,10 @@ class AdvertisingData:
# -----------------------------------------------------------------------------
# Connection PHY
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ConnectionPHY:
def __init__(self, tx_phy, rx_phy):
self.tx_phy = tx_phy
self.rx_phy = rx_phy
def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
tx_phy: int
rx_phy: int
# -----------------------------------------------------------------------------

View File

@@ -1837,6 +1837,7 @@ class Connection(utils.CompositeEventEmitter):
self.pairing_peer_io_capability = None
self.pairing_peer_authentication_requirements = None
self.peer_le_features = hci.LeFeatureMask(0)
self.peer_classic_features = hci.LmpFeatureMask(0)
self.cs_configs = {}
self.cs_procedures = {}
@@ -2054,6 +2055,15 @@ class Connection(utils.CompositeEventEmitter):
self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features
async def get_remote_classic_features(self) -> hci.LmpFeatureMask:
"""[Classic Only] Reads remote LMP supported features.
Returns:
LMP features supported by the remote device.
"""
self.peer_classic_features = await self.device.get_remote_classic_features(self)
return self.peer_classic_features
def on_att_mtu_update(self, mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
@@ -2149,6 +2159,7 @@ class DeviceConfiguration:
)
eatt_enabled: bool = False
gatt_services: list[dict[str, Any]] = field(init=False)
smp_debug_mode: bool = False
def __post_init__(self) -> None:
self.gatt_services = []
@@ -2561,6 +2572,7 @@ class Device(utils.CompositeEventEmitter):
),
),
)
self.smp_manager.debug_mode = self.config.smp_debug_mode
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -5281,6 +5293,77 @@ class Device(utils.CompositeEventEmitter):
)
return await read_feature_future
async def get_remote_classic_features(
self, connection: Connection
) -> hci.LmpFeatureMask:
"""[Classic Only] Reads remote LE supported features.
Args:
handle: connection handle to read LMP features.
Returns:
LMP features supported by the remote device.
"""
with closing(utils.EventWatcher()) as watcher:
read_feature_future: asyncio.Future[tuple[int, int]] = (
asyncio.get_running_loop().create_future()
)
read_features = hci.LmpFeatureMask(0)
current_page_number = 0
@watcher.on(self.host, 'classic_remote_features')
def on_classic_remote_features(
handle: int,
status: int,
features: int,
page_number: int,
max_page_number: int,
) -> None:
if handle != connection.handle:
logger.warning(
"Received classic_remote_features for wrong handle, expected=0x%04X, got=0x%04X",
connection.handle,
handle,
)
return
if page_number != current_page_number:
logger.warning(
"Received classic_remote_features for wrong page, expected=%d, got=%d",
current_page_number,
page_number,
)
return
if status == hci.HCI_ErrorCode.SUCCESS:
read_feature_future.set_result((features, max_page_number))
else:
read_feature_future.set_exception(hci.HCI_Error(status))
await self.send_async_command(
hci.HCI_Read_Remote_Supported_Features_Command(
connection_handle=connection.handle
)
)
new_features, max_page_number = await read_feature_future
read_features |= new_features
if not (read_features & hci.LmpFeatureMask.EXTENDED_FEATURES):
return read_features
while current_page_number <= max_page_number:
read_feature_future = asyncio.get_running_loop().create_future()
await self.send_async_command(
hci.HCI_Read_Remote_Extended_Features_Command(
connection_handle=connection.handle,
page_number=current_page_number,
)
)
new_features, max_page_number = await read_feature_future
read_features |= new_features << (current_page_number * 64)
current_page_number += 1
return read_features
@utils.experimental('Only for testing.')
async def get_remote_cs_capabilities(
self, connection: Connection

View File

@@ -201,50 +201,51 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
value = data[2 : 2 + value_length]
typed_value: Any
if value_type == ValueType.END:
break
match value_type:
case ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
case ValueType.CNVI | ValueType.CNVR:
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
case ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
case (
ValueType.USB_VENDOR_ID
| ValueType.USB_PRODUCT_ID
| ValueType.DEVICE_REVISION
):
(typed_value,) = struct.unpack("<H", value)
case ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
case (
ValueType.BUILD_TYPE
| ValueType.BUILD_NUMBER
| ValueType.SECURE_BOOT
| ValueType.OTP_LOCK
| ValueType.API_LOCK
| ValueType.DEBUG_LOCK
| ValueType.SECURE_BOOT_ENGINE_TYPE
):
typed_value = value[0]
case ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
case ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
case ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
case _:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]

View File

@@ -1,60 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from bumble.gatt import (
GATT_APPEARANCE_CHARACTERISTIC,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
Characteristic,
Service,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)

View File

@@ -31,6 +31,7 @@ from typing import (
ClassVar,
Generic,
Literal,
SupportsBytes,
TypeVar,
cast,
)
@@ -247,28 +248,6 @@ HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
HCI_VERSION_BLUETOOTH_CORE_1_1: 'HCI_VERSION_BLUETOOTH_CORE_1_1',
HCI_VERSION_BLUETOOTH_CORE_1_2: 'HCI_VERSION_BLUETOOTH_CORE_1_2',
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_0_EDR',
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_1_EDR',
HCI_VERSION_BLUETOOTH_CORE_3_0_HS: 'HCI_VERSION_BLUETOOTH_CORE_3_0_HS',
HCI_VERSION_BLUETOOTH_CORE_4_0: 'HCI_VERSION_BLUETOOTH_CORE_4_0',
HCI_VERSION_BLUETOOTH_CORE_4_1: 'HCI_VERSION_BLUETOOTH_CORE_4_1',
HCI_VERSION_BLUETOOTH_CORE_4_2: 'HCI_VERSION_BLUETOOTH_CORE_4_2',
HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0',
HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1',
HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2',
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
}
LMP_VERSION_NAMES = HCI_VERSION_NAMES
# HCI Packet types
HCI_COMMAND_PACKET = 0x01
HCI_ACL_DATA_PACKET = 0x02
@@ -387,8 +366,8 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2A
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2B
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
@@ -1860,44 +1839,46 @@ class HCI_Object:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_type == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_length = data[offset]
offset += 1
field_value = data[offset : offset + field_length]
return (field_value, field_length + 1)
if field_type == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_type == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_type == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_type == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_type == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_type == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_type == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
match field_type:
case '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
case 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_length = data[offset]
offset += 1
field_value = data[offset : offset + field_length]
return (field_value, field_length + 1)
case 1:
# 8-bit unsigned
return (data[offset], 1)
case -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
case 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
case '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
case -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
case 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
case 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
case '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
case int() if 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type):
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
@@ -1954,60 +1935,58 @@ class HCI_Object:
# Serialize the field
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
return serializer(field_value)
match field_type:
case 1:
# 8-bit unsigned
return bytes([field_value])
case -1:
# 8-bit signed
return struct.pack('b', field_value)
case 2:
# 16-bit unsigned
return struct.pack('<H', field_value)
case '>2':
# 16-bit unsigned big-endian
return struct.pack('>H', field_value)
case -2:
# 16-bit signed
return struct.pack('<h', field_value)
case 3:
# 24-bit unsigned
return struct.pack('<I', field_value)[0:3]
case 4:
# 32-bit unsigned
return struct.pack('<I', field_value)
case '>4':
# 32-bit unsigned big-endian
return struct.pack('>I', field_value)
case '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
return bytes([field_value])
else:
raise InvalidArgumentError('value too large for *-typed field')
else:
raise InvalidArgumentError('value too large for *-typed field')
else:
return bytes(field_value)
case 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_bytes = bytes(field_value)
elif field_type == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_bytes = bytes(field_value)
field_length = len(field_bytes)
field_bytes = bytes([field_length]) + field_bytes
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, '__bytes__'
):
field_length = len(field_bytes)
return bytes([field_length]) + field_bytes
if isinstance(field_value, (bytes, bytearray, SupportsBytes)):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
return field_bytes + bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes[:field_type]
return field_bytes
return field_bytes
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
@staticmethod
def dict_to_bytes(hci_object, object_fields):
@@ -4736,7 +4715,7 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_SyncCommand[HCI_StatusReturnParame
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class HCI_LE_Read_Resolving_List_Size_ReturnParameters(HCI_StatusReturnParameters):
resolving_list_size: bytes = field(metadata=metadata(1))
resolving_list_size: int = field(metadata=metadata(1))
@HCI_SyncCommand.sync_command(HCI_LE_Read_Resolving_List_Size_ReturnParameters)

View File

@@ -26,7 +26,7 @@ import logging
import re
import traceback
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar
from typing import Any, ClassVar, Literal, overload
from typing_extensions import Self
@@ -420,61 +420,6 @@ class CmeError(enum.IntEnum):
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = {
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
}
# Unsolicited responses and statuses.
UNSOLICITED_CODES = {
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
}
# Status codes
STATUS_CODES = {
"+CME ERROR",
@@ -727,12 +672,9 @@ class HfProtocol(utils.EventEmitter):
dlc: rfcomm.DLC
command_lock: asyncio.Lock
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
pending_command: str | None = None
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
read_buffer: bytearray
active_codec: AudioCodec
@@ -805,16 +747,39 @@ class HfProtocol(utils.EventEmitter):
self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
if self.pending_command and (
response.code in STATUS_CODES or response.code in self.pending_command
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(
f"dropping unexpected response with code '{response.code}'"
)
self.unsolicited_queue.put_nowait(response)
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
) -> None: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.SINGLE],
) -> AtResponse: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.MULTIPLE],
) -> list[AtResponse]: ...
async def execute_command(
self,
@@ -835,27 +800,34 @@ class HfProtocol(utils.EventEmitter):
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
ProtocolError: the status is not OK.
"""
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
try:
async with self.command_lock:
self.pending_command = cmd
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise HfpProtocolError("NO ANSWER")
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if (
response_type == AtResponseType.SINGLE
and len(responses) != 1
):
raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
finally:
self.pending_command = None
async def initiate_slc(self):
"""4.2.1 Service Level Connection Initialization."""
@@ -1067,7 +1039,6 @@ class HfProtocol(utils.EventEmitter):
responses = await self.execute_command(
"AT+CLCC", response_type=AtResponseType.MULTIPLE
)
assert isinstance(responses, list)
calls = []
for response in responses:

View File

@@ -22,7 +22,7 @@ import collections
import dataclasses
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, TypeVar, overload
from bumble import drivers, hci, utils
from bumble.colors import color
@@ -1002,18 +1002,19 @@ class Host(utils.EventEmitter):
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
self.on_hci_command_packet(cast(hci.HCI_Command, packet))
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
self.on_hci_event_packet(cast(hci.HCI_Event, packet))
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
match packet:
case hci.HCI_Command():
self.on_hci_command_packet(packet)
case hci.HCI_Event():
self.on_hci_event_packet(packet)
case hci.HCI_AclDataPacket():
self.on_hci_acl_data_packet(packet)
case hci.HCI_SynchronousDataPacket():
self.on_hci_sco_data_packet(packet)
case hci.HCI_IsoDataPacket():
self.on_hci_iso_data_packet(packet)
case _:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}')
@@ -1659,6 +1660,19 @@ class Host(utils.EventEmitter):
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_read_remote_supported_features_complete_event(
self, event: hci.HCI_Read_Remote_Supported_Features_Complete_Event
) -> None:
# Notify the client
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.lmp_features, 'little'),
0, # page number
0, # max page number
)
def on_hci_encryption_change_v2_event(
self, event: hci.HCI_Encryption_Change_V2_Event
):
@@ -1815,6 +1829,18 @@ class Host(utils.EventEmitter):
rssi,
)
def on_hci_read_remote_extended_features_complete_event(
self, event: hci.HCI_Read_Remote_Extended_Features_Complete_Event
):
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.extended_lmp_features, 'little'),
event.page_number,
event.maximum_page_number,
)
def on_hci_extended_inquiry_result_event(
self, event: hci.HCI_Extended_Inquiry_Result_Event
):

View File

@@ -27,6 +27,7 @@ import dataclasses
import json
import logging
import os
import pathlib
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
@@ -248,29 +249,26 @@ class JsonKeyStore(KeyStore):
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
def __init__(
self, namespace: str | None = None, filename: str | None = None
) -> None:
self.namespace = namespace or self.DEFAULT_NAMESPACE
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
if filename:
self.filename = pathlib.Path(filename).resolve()
self.directory_name = self.filename.parent
else:
self.filename = filename
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
import platformdirs # Deferred import
logger.debug(f'JSON keystore: {self.filename}')
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR)
self.directory_name = base_dir / self.KEYS_DIR
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
safe_name = base_name.lower().replace(':', '-').replace('/', '-')
self.filename = self.directory_name / f"{safe_name}.json"
logger.debug('JSON keystore: %s', self.filename)
@classmethod
def from_device(
@@ -293,7 +291,9 @@ class JsonKeyStore(KeyStore):
return cls(namespace, filename)
async def load(self):
async def load(
self,
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
@@ -312,17 +312,17 @@ class JsonKeyStore(KeyStore):
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
key_map: dict[str, dict[str, Any]] = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db):
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None:
# Create the directory if it doesn't exist
if not os.path.exists(self.directory_name):
os.makedirs(self.directory_name, exist_ok=True)
if not self.directory_name.exists():
self.directory_name.mkdir(parents=True, exist_ok=True)
# Save to a temporary file
temp_filename = self.filename + '.tmp'
temp_filename = self.filename.with_name(self.filename.name + ".tmp")
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)
@@ -334,16 +334,16 @@ class JsonKeyStore(KeyStore):
del key_map[name]
await self.save(db)
async def update(self, name, keys):
async def update(self, name: str, keys: PairingKeys) -> None:
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self):
async def get_all(self) -> list[tuple[str, PairingKeys]]:
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self):
async def delete_all(self) -> None:
db, key_map = await self.load()
key_map.clear()
await self.save(db)

View File

@@ -198,3 +198,24 @@ class CisTerminateInd(ControlPdu):
cig_id: int
cis_id: int
error_code: int
@dataclasses.dataclass
class FeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
feature_set: bytes
@dataclasses.dataclass
class FeatureRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
feature_set: bytes
@dataclasses.dataclass
class PeripheralFeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
feature_set: bytes

View File

@@ -322,3 +322,38 @@ class LmpNameRes(Packet):
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))
@Packet.subclass
@dataclass
class LmpFeaturesReq(Packet):
opcode = Opcode.LMP_FEATURES_REQ
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesRes(Packet):
opcode = Opcode.LMP_FEATURES_RES
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesReqExt(Packet):
opcode = Opcode.LMP_FEATURES_REQ_EXT
features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesResExt(Packet):
opcode = Opcode.LMP_FEATURES_RES_EXT
features_page: int = field(metadata=hci.metadata(1))
max_features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))

View File

@@ -21,18 +21,9 @@ import enum
import secrets
from dataclasses import dataclass
from bumble import hci
from bumble import hci, smp
from bumble.core import AdvertisingData, LeRole
from bumble.smp import (
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
@@ -96,11 +87,11 @@ class PairingDelegate:
# These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants.
class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY
# Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
@@ -111,10 +102,10 @@ class PairingDelegate:
# Key Distribution [LE only]
class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY

View File

@@ -664,46 +664,44 @@ class AudioStreamControlService(gatt.TemplateService):
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if isinstance(operation, ASE_Config_Codec):
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
match operation:
case ASE_Config_Codec():
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Config_QOS():
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Enable() | ASE_Update_Metadata():
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case (
ASE_Receiver_Start_Ready()
| ASE_Disable()
| ASE_Receiver_Stop_Ready()
| ASE_Release()
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, ASE_Config_QOS):
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(
operation,
(
ASE_Receiver_Start_Ready,
ASE_Disable,
ASE_Receiver_Stop_Ready,
ASE_Release,
),
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]

View File

@@ -333,17 +333,18 @@ class CodecSpecificCapabilities:
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
match type:
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
case CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment

View File

@@ -55,14 +55,15 @@ class GenericAccessService(TemplateService):
def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
match appearance:
case int():
appearance_int = appearance
case tuple():
appearance_int = (appearance[0] << 6) | appearance[1]
case Appearance():
appearance_int = int(appearance)
case _:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,

View File

@@ -21,11 +21,12 @@ import asyncio
import logging
import struct
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, NewType
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar
from typing_extensions import Self
from bumble import core, l2cap
from bumble import core, hci, l2cap, utils
from bumble.colors import color
from bumble.core import (
InvalidArgumentError,
@@ -33,7 +34,6 @@ from bumble.core import (
InvalidStateError,
ProtocolError,
)
from bumble.hci import HCI_Object, key_with_value, name_or_number
if TYPE_CHECKING:
from bumble.device import Connection, Device
@@ -54,39 +54,22 @@ SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing
SDP_PSM = 0x0001
SDP_ERROR_RESPONSE = 0x01
SDP_SERVICE_SEARCH_REQUEST = 0x02
SDP_SERVICE_SEARCH_RESPONSE = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07
class PduId(hci.SpecableEnum):
SDP_ERROR_RESPONSE = 0x01
SDP_SERVICE_SEARCH_REQUEST = 0x02
SDP_SERVICE_SEARCH_RESPONSE = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07
SDP_PDU_NAMES = {
SDP_ERROR_RESPONSE: 'SDP_ERROR_RESPONSE',
SDP_SERVICE_SEARCH_REQUEST: 'SDP_SERVICE_SEARCH_REQUEST',
SDP_SERVICE_SEARCH_RESPONSE: 'SDP_SERVICE_SEARCH_RESPONSE',
SDP_SERVICE_ATTRIBUTE_REQUEST: 'SDP_SERVICE_ATTRIBUTE_REQUEST',
SDP_SERVICE_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_ATTRIBUTE_RESPONSE',
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: 'SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST',
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE'
}
SDP_INVALID_SDP_VERSION_ERROR = 0x0001
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR = 0x0002
SDP_INVALID_REQUEST_SYNTAX_ERROR = 0x0003
SDP_INVALID_PDU_SIZE_ERROR = 0x0004
SDP_INVALID_CONTINUATION_STATE_ERROR = 0x0005
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR = 0x0006
SDP_ERROR_NAMES = {
SDP_INVALID_SDP_VERSION_ERROR: 'SDP_INVALID_SDP_VERSION_ERROR',
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR: 'SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR',
SDP_INVALID_REQUEST_SYNTAX_ERROR: 'SDP_INVALID_REQUEST_SYNTAX_ERROR',
SDP_INVALID_PDU_SIZE_ERROR: 'SDP_INVALID_PDU_SIZE_ERROR',
SDP_INVALID_CONTINUATION_STATE_ERROR: 'SDP_INVALID_CONTINUATION_STATE_ERROR',
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR: 'SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR'
}
class ErrorCode(hci.SpecableEnum):
INVALID_SDP_VERSION = 0x0001
INVALID_SERVICE_RECORD_HANDLE = 0x0002
INVALID_REQUEST_SYNTAX = 0x0003
INVALID_PDU_SIZE = 0x0004
INVALID_CONTINUATION_STATE = 0x0005
INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST = 0x0006
SDP_SERVICE_NAME_ATTRIBUTE_ID_OFFSET = 0x0000
SDP_SERVICE_DESCRIPTION_ATTRIBUTE_ID_OFFSET = 0x0001
@@ -141,30 +124,31 @@ SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF)
# -----------------------------------------------------------------------------
@dataclass
class DataElement:
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
TYPE_NAMES = {
NIL: 'NIL',
UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL',
}
class Type(utils.OpenIntEnum):
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
type_constructors = {
NIL = Type.NIL
UNSIGNED_INTEGER = Type.UNSIGNED_INTEGER
SIGNED_INTEGER = Type.SIGNED_INTEGER
UUID = Type.UUID
TEXT_STRING = Type.TEXT_STRING
BOOLEAN = Type.BOOLEAN
SEQUENCE = Type.SEQUENCE
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
@@ -190,14 +174,18 @@ class DataElement:
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
def __init__(self, element_type, value, value_size=None):
self.type = element_type
self.value = value
self.value_size = value_size
type: Type
value: Any
value_size: int | None = None
def __post_init__(self) -> None:
# Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
self._bytes: bytes | None = None
if self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
if self.value_size is None:
raise InvalidArgumentError(
'integer types must have a value size specified'
)
@@ -337,7 +325,7 @@ class DataElement:
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(element_type)
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
@@ -348,15 +336,15 @@ class DataElement:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result.bytes = data[
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
# Return early if we have a cache
if self.bytes:
return self.bytes
if self._bytes:
return self._bytes
if self.type == DataElement.NIL:
data = b''
@@ -443,12 +431,12 @@ class DataElement:
else:
raise RuntimeError("internal error - self.type not supported")
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self.bytes
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty=False, indentation=0):
prefix = ' ' * indentation
type_name = name_or_number(self.TYPE_NAMES, self.type)
type_name = self.type.name
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
@@ -476,10 +464,10 @@ class DataElement:
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
def __init__(self, attribute_id: int, value: DataElement) -> None:
self.id = attribute_id
self.value = value
id: int
value: DataElement
@staticmethod
def list_from_data_elements(
@@ -510,7 +498,7 @@ class ServiceAttribute:
@staticmethod
def id_name(id_code):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
return hci.name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
@@ -540,239 +528,223 @@ class ServiceAttribute:
# -----------------------------------------------------------------------------
def _parse_service_record_handle_list(
data: bytes, offset: int
) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset)[0]
offset += 2
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
def _serialize_service_record_handle_list(
handles: list[int],
) -> bytes:
return struct.pack('>H', len(handles)) + b''.join(
struct.pack('>I', handle) for handle in handles
)
def _parse_bytes_preceded_by_length(data: bytes, offset: int) -> tuple[int, bytes]:
length = struct.unpack_from('>H', data, offset)[0]
offset += 2
return offset + length, data[offset : offset + length]
def _serialize_bytes_preceded_by_length(data: bytes) -> bytes:
return struct.pack('>H', len(data)) + data
_SERVICE_RECORD_HANDLE_LIST_METADATA = hci.metadata(
{
'parser': _parse_service_record_handle_list,
'serializer': _serialize_service_record_handle_list,
}
)
_BYTES_PRECEDED_BY_LENGTH_METADATA = hci.metadata(
{
'parser': _parse_bytes_preceded_by_length,
'serializer': _serialize_bytes_preceded_by_length,
}
)
# -----------------------------------------------------------------------------
@dataclass
class SDP_PDU:
'''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
RESPONSE_PDU_IDS = {
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
PduId.SDP_SERVICE_SEARCH_REQUEST: PduId.SDP_SERVICE_SEARCH_RESPONSE,
PduId.SDP_SERVICE_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE,
PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
}
sdp_pdu_classes: dict[int, type[SDP_PDU]] = {}
name = None
pdu_id = 0
subclasses: ClassVar[dict[int, type[SDP_PDU]]] = {}
pdu_id: ClassVar[PduId]
fields: ClassVar[hci.Fields]
@staticmethod
def from_bytes(pdu):
transaction_id: int
_payload: bytes | None = field(init=False, repr=False, default=None)
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None:
instance = SDP_PDU(pdu)
instance.name = SDP_PDU.pdu_name(pdu_id)
instance.pdu_id = pdu_id
instance.transaction_id = transaction_id
return instance
self = cls.__new__(cls)
SDP_PDU.__init__(self, pdu, transaction_id)
if hasattr(self, 'fields'):
self.init_from_bytes(pdu, 5)
return self
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
raise InvalidPacketError(f"Unknown PDU type {pdu_id}")
instance = subclass(
transaction_id=transaction_id,
**hci.HCI_Object.dict_from_bytes(pdu, 5, subclass.fields),
)
instance._payload = pdu
return instance
@staticmethod
def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int
) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
_PDU = TypeVar('_PDU', bound='SDP_PDU')
@staticmethod
def parse_bytes_preceded_by_length(data, offset):
length = struct.unpack_from('>H', data, offset - 2)[0]
return offset + length, data[offset : offset + length]
@staticmethod
def error_name(error_code):
return name_or_number(SDP_ERROR_NAMES, error_code)
@staticmethod
def pdu_name(code):
return name_or_number(SDP_PDU_NAMES, code)
@staticmethod
def subclass(fields):
def inner(cls):
name = cls.__name__
# add a _ character before every uppercase letter, except the SDP_ prefix
location = len(name) - 1
while location > 4:
if not name[location].isupper():
location -= 1
continue
name = name[:location] + '_' + name[location:]
location -= 1
cls.name = name.upper()
cls.pdu_id = key_with_value(SDP_PDU_NAMES, cls.name)
if cls.pdu_id is None:
raise KeyError(f'PDU name {cls.name} not found in SDP_PDU_NAMES')
cls.fields = fields
# Register a factory for this class
SDP_PDU.sdp_pdu_classes[cls.pdu_id] = cls
return cls
return inner
def __init__(self, pdu=None, transaction_id=0, **kwargs):
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu
self.transaction_id = transaction_id
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
@classmethod
def subclass(cls, subclass: type[_PDU]) -> type[_PDU]:
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
cls.subclasses[subclass.pdu_id] = subclass
return subclass
def __bytes__(self):
return self.pdu
if self._payload is None:
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@property
def name(self) -> str:
return self.pdu_id.name
def __str__(self):
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
if fields := getattr(self, 'fields', None):
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ')
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
elif len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
@SDP_PDU.subclass
@dataclass
class SDP_ErrorResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
'''
error_code: int
pdu_id = PduId.SDP_ERROR_RESPONSE
error_code: ErrorCode = field(metadata=ErrorCode.type_metadata(2))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceSearchRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
'''
service_search_pattern: DataElement
maximum_service_record_count: int
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_SEARCH_REQUEST
service_search_pattern: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
maximum_service_record_count: int = field(metadata=hci.metadata('>2'))
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
(
'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceSearchResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
'''
service_record_handle_list: list[int]
total_service_record_count: int
current_service_record_count: int
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_SEARCH_RESPONSE
total_service_record_count: int = field(metadata=hci.metadata('>2'))
service_record_handle_list: Sequence[int] = field(
metadata=_SERVICE_RECORD_HANDLE_LIST_METADATA
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
'''
service_record_handle: int
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_REQUEST
service_record_handle: int = field(metadata=hci.metadata('>4'))
maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2'))
attribute_id_list: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
'''
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE
attribute_list: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA)
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
'''
service_search_pattern: DataElement
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST
service_search_pattern: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2'))
attribute_id_list: DataElement = field(
metadata=hci.metadata(DataElement.parse_from_bytes)
)
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass
@dataclass
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
'''
attribute_lists_byte_count: int
attribute_lists: bytes
continuation_state: bytes
pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE
attribute_lists: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA)
continuation_state: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@@ -873,7 +845,7 @@ class Client:
)
# Request and accumulate until there's no more continuation
service_record_handle_list = []
service_record_handle_list: list[int] = []
continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0:
@@ -1091,7 +1063,7 @@ class Server:
logger.exception(color('failed to parse SDP Request PDU', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
transaction_id=0, error_code=ErrorCode.INVALID_REQUEST_SYNTAX
)
)
@@ -1108,7 +1080,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
error_code=ErrorCode.INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST,
)
)
else:
@@ -1116,7 +1088,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
error_code=ErrorCode.INVALID_REQUEST_SYNTAX,
)
)
@@ -1134,7 +1106,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
error_code=ErrorCode.INVALID_CONTINUATION_STATE,
)
)
return None
@@ -1228,15 +1200,11 @@ class Server:
if service_record_handles_remaining
else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
)
self.send_response(
SDP_ServiceSearchResponse(
transaction_id=request.transaction_id,
total_service_record_count=total_service_record_count,
current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list,
service_record_handle_list=service_record_handles,
continuation_state=continuation_state,
)
)
@@ -1259,7 +1227,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
error_code=ErrorCode.INVALID_SERVICE_RECORD_HANDLE,
)
)
return
@@ -1284,7 +1252,6 @@ class Server:
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list_response,
continuation_state=continuation_state,
)
@@ -1331,7 +1298,6 @@ class Server:
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists_response,
continuation_state=continuation_state,
)

View File

@@ -31,14 +31,13 @@ from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
from bumble import crypto, utils
from bumble import crypto, hci, utils
from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
PhysicalTransport,
ProtocolError,
name_or_number,
)
from bumble.hci import (
Address,
@@ -46,7 +45,6 @@ from bumble.hci import (
HCI_LE_Enable_Encryption_Command,
HCI_Object,
Role,
key_with_value,
metadata,
)
from bumble.keys import PairingKeys
@@ -71,115 +69,125 @@ logger = logging.getLogger(__name__)
SMP_CID = 0x06
SMP_BR_CID = 0x07
SMP_PAIRING_REQUEST_COMMAND = 0x01
SMP_PAIRING_RESPONSE_COMMAND = 0x02
SMP_PAIRING_CONFIRM_COMMAND = 0x03
SMP_PAIRING_RANDOM_COMMAND = 0x04
SMP_PAIRING_FAILED_COMMAND = 0x05
SMP_ENCRYPTION_INFORMATION_COMMAND = 0x06
SMP_MASTER_IDENTIFICATION_COMMAND = 0x07
SMP_IDENTITY_INFORMATION_COMMAND = 0x08
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND = 0x09
SMP_SIGNING_INFORMATION_COMMAND = 0x0A
SMP_SECURITY_REQUEST_COMMAND = 0x0B
SMP_PAIRING_PUBLIC_KEY_COMMAND = 0x0C
SMP_PAIRING_DHKEY_CHECK_COMMAND = 0x0D
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND = 0x0E
class CommandCode(hci.SpecableEnum):
PAIRING_REQUEST = 0x01
PAIRING_RESPONSE = 0x02
PAIRING_CONFIRM = 0x03
PAIRING_RANDOM = 0x04
PAIRING_FAILED = 0x05
ENCRYPTION_INFORMATION = 0x06
MASTER_IDENTIFICATION = 0x07
IDENTITY_INFORMATION = 0x08
IDENTITY_ADDRESS_INFORMATION = 0x09
SIGNING_INFORMATION = 0x0A
SECURITY_REQUEST = 0x0B
PAIRING_PUBLIC_KEY = 0x0C
PAIRING_DHKEY_CHECK = 0x0D
PAIRING_KEYPRESS_NOTIFICATION = 0x0E
SMP_COMMAND_NAMES = {
SMP_PAIRING_REQUEST_COMMAND: 'SMP_PAIRING_REQUEST_COMMAND',
SMP_PAIRING_RESPONSE_COMMAND: 'SMP_PAIRING_RESPONSE_COMMAND',
SMP_PAIRING_CONFIRM_COMMAND: 'SMP_PAIRING_CONFIRM_COMMAND',
SMP_PAIRING_RANDOM_COMMAND: 'SMP_PAIRING_RANDOM_COMMAND',
SMP_PAIRING_FAILED_COMMAND: 'SMP_PAIRING_FAILED_COMMAND',
SMP_ENCRYPTION_INFORMATION_COMMAND: 'SMP_ENCRYPTION_INFORMATION_COMMAND',
SMP_MASTER_IDENTIFICATION_COMMAND: 'SMP_MASTER_IDENTIFICATION_COMMAND',
SMP_IDENTITY_INFORMATION_COMMAND: 'SMP_IDENTITY_INFORMATION_COMMAND',
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND: 'SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND',
SMP_SIGNING_INFORMATION_COMMAND: 'SMP_SIGNING_INFORMATION_COMMAND',
SMP_SECURITY_REQUEST_COMMAND: 'SMP_SECURITY_REQUEST_COMMAND',
SMP_PAIRING_PUBLIC_KEY_COMMAND: 'SMP_PAIRING_PUBLIC_KEY_COMMAND',
SMP_PAIRING_DHKEY_CHECK_COMMAND: 'SMP_PAIRING_DHKEY_CHECK_COMMAND',
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND: 'SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND'
}
SMP_DISPLAY_ONLY_IO_CAPABILITY = 0x00
SMP_DISPLAY_YES_NO_IO_CAPABILITY = 0x01
SMP_KEYBOARD_ONLY_IO_CAPABILITY = 0x02
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = 0x03
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = 0x04
class IoCapability(hci.SpecableEnum):
DISPLAY_ONLY = 0x00
DISPLAY_YES_NO = 0x01
KEYBOARD_ONLY = 0x02
NO_INPUT_NO_OUTPUT = 0x03
KEYBOARD_DISPLAY = 0x04
SMP_IO_CAPABILITY_NAMES = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: 'SMP_DISPLAY_ONLY_IO_CAPABILITY',
SMP_DISPLAY_YES_NO_IO_CAPABILITY: 'SMP_DISPLAY_YES_NO_IO_CAPABILITY',
SMP_KEYBOARD_ONLY_IO_CAPABILITY: 'SMP_KEYBOARD_ONLY_IO_CAPABILITY',
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: 'SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY',
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: 'SMP_KEYBOARD_DISPLAY_IO_CAPABILITY'
}
SMP_DISPLAY_ONLY_IO_CAPABILITY = IoCapability.DISPLAY_ONLY
SMP_DISPLAY_YES_NO_IO_CAPABILITY = IoCapability.DISPLAY_YES_NO
SMP_KEYBOARD_ONLY_IO_CAPABILITY = IoCapability.KEYBOARD_ONLY
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = IoCapability.NO_INPUT_NO_OUTPUT
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = IoCapability.KEYBOARD_DISPLAY
SMP_PASSKEY_ENTRY_FAILED_ERROR = 0x01
SMP_OOB_NOT_AVAILABLE_ERROR = 0x02
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = 0x03
SMP_CONFIRM_VALUE_FAILED_ERROR = 0x04
SMP_PAIRING_NOT_SUPPORTED_ERROR = 0x05
SMP_ENCRYPTION_KEY_SIZE_ERROR = 0x06
SMP_COMMAND_NOT_SUPPORTED_ERROR = 0x07
SMP_UNSPECIFIED_REASON_ERROR = 0x08
SMP_REPEATED_ATTEMPTS_ERROR = 0x09
SMP_INVALID_PARAMETERS_ERROR = 0x0A
SMP_DHKEY_CHECK_FAILED_ERROR = 0x0B
SMP_NUMERIC_COMPARISON_FAILED_ERROR = 0x0C
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = 0x0D
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = 0x0E
class ErrorCode(hci.SpecableEnum):
PASSKEY_ENTRY_FAILED = 0x01
OOB_NOT_AVAILABLE = 0x02
AUTHENTICATION_REQUIREMENTS = 0x03
CONFIRM_VALUE_FAILED = 0x04
PAIRING_NOT_SUPPORTED = 0x05
ENCRYPTION_KEY_SIZE = 0x06
COMMAND_NOT_SUPPORTED = 0x07
UNSPECIFIED_REASON = 0x08
REPEATED_ATTEMPTS = 0x09
INVALID_PARAMETERS = 0x0A
DHKEY_CHECK_FAILED = 0x0B
NUMERIC_COMPARISON_FAILED = 0x0C
BD_EDR_PAIRING_IN_PROGRESS = 0x0D
CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED = 0x0E
SMP_ERROR_NAMES = {
SMP_PASSKEY_ENTRY_FAILED_ERROR: 'SMP_PASSKEY_ENTRY_FAILED_ERROR',
SMP_OOB_NOT_AVAILABLE_ERROR: 'SMP_OOB_NOT_AVAILABLE_ERROR',
SMP_AUTHENTICATION_REQUIREMENTS_ERROR: 'SMP_AUTHENTICATION_REQUIREMENTS_ERROR',
SMP_CONFIRM_VALUE_FAILED_ERROR: 'SMP_CONFIRM_VALUE_FAILED_ERROR',
SMP_PAIRING_NOT_SUPPORTED_ERROR: 'SMP_PAIRING_NOT_SUPPORTED_ERROR',
SMP_ENCRYPTION_KEY_SIZE_ERROR: 'SMP_ENCRYPTION_KEY_SIZE_ERROR',
SMP_COMMAND_NOT_SUPPORTED_ERROR: 'SMP_COMMAND_NOT_SUPPORTED_ERROR',
SMP_UNSPECIFIED_REASON_ERROR: 'SMP_UNSPECIFIED_REASON_ERROR',
SMP_REPEATED_ATTEMPTS_ERROR: 'SMP_REPEATED_ATTEMPTS_ERROR',
SMP_INVALID_PARAMETERS_ERROR: 'SMP_INVALID_PARAMETERS_ERROR',
SMP_DHKEY_CHECK_FAILED_ERROR: 'SMP_DHKEY_CHECK_FAILED_ERROR',
SMP_NUMERIC_COMPARISON_FAILED_ERROR: 'SMP_NUMERIC_COMPARISON_FAILED_ERROR',
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR: 'SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR',
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR: 'SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR'
}
SMP_PASSKEY_ENTRY_FAILED_ERROR = ErrorCode.PASSKEY_ENTRY_FAILED
SMP_OOB_NOT_AVAILABLE_ERROR = ErrorCode.OOB_NOT_AVAILABLE
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = ErrorCode.AUTHENTICATION_REQUIREMENTS
SMP_CONFIRM_VALUE_FAILED_ERROR = ErrorCode.CONFIRM_VALUE_FAILED
SMP_PAIRING_NOT_SUPPORTED_ERROR = ErrorCode.PAIRING_NOT_SUPPORTED
SMP_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.ENCRYPTION_KEY_SIZE
SMP_COMMAND_NOT_SUPPORTED_ERROR = ErrorCode.COMMAND_NOT_SUPPORTED
SMP_UNSPECIFIED_REASON_ERROR = ErrorCode.UNSPECIFIED_REASON
SMP_REPEATED_ATTEMPTS_ERROR = ErrorCode.REPEATED_ATTEMPTS
SMP_INVALID_PARAMETERS_ERROR = ErrorCode.INVALID_PARAMETERS
SMP_DHKEY_CHECK_FAILED_ERROR = ErrorCode.DHKEY_CHECK_FAILED
SMP_NUMERIC_COMPARISON_FAILED_ERROR = ErrorCode.NUMERIC_COMPARISON_FAILED
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = ErrorCode.BD_EDR_PAIRING_IN_PROGRESS
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE = 0
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE = 1
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE = 2
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE = 3
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE = 4
SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES = {
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE'
}
class KeypressNotificationType(hci.SpecableEnum):
PASSKEY_ENTRY_STARTED = 0
PASSKEY_DIGIT_ENTERED = 1
PASSKEY_DIGIT_ERASED = 2
PASSKEY_CLEARED = 3
PASSKEY_ENTRY_COMPLETED = 4
# Bit flags for key distribution/generation
SMP_ENC_KEY_DISTRIBUTION_FLAG = 0b0001
SMP_ID_KEY_DISTRIBUTION_FLAG = 0b0010
SMP_SIGN_KEY_DISTRIBUTION_FLAG = 0b0100
SMP_LINK_KEY_DISTRIBUTION_FLAG = 0b1000
class KeyDistribution(hci.SpecableFlag):
ENC_KEY = 0b0001
ID_KEY = 0b0010
SIGN_KEY = 0b0100
LINK_KEY = 0b1000
# AuthReq fields
SMP_BONDING_AUTHREQ = 0b00000001
SMP_MITM_AUTHREQ = 0b00000100
SMP_SC_AUTHREQ = 0b00001000
SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000
class AuthReq(hci.SpecableFlag):
BONDING = 0b00000001
MITM = 0b00000100
SC = 0b00001000
KEYPRESS = 0b00010000
CT2 = 0b00100000
@classmethod
def from_booleans(
cls,
bonding: bool = False,
sc: bool = False,
mitm: bool = False,
keypress: bool = False,
ct2: bool = False,
) -> AuthReq:
auth_req = AuthReq(0)
if bonding:
auth_req |= AuthReq.BONDING
if sc:
auth_req |= AuthReq.SC
if mitm:
auth_req |= AuthReq.MITM
if keypress:
auth_req |= AuthReq.KEYPRESS
if ct2:
auth_req |= AuthReq.CT2
return auth_req
# Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# Diffie-Hellman private / public key pair in Debug Mode (Core - Vol. 3, Part H)
SMP_DEBUG_KEY_PRIVATE = bytes.fromhex(
'3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd'
)
SMP_DEBUG_KEY_PUBLIC_X = bytes.fromhex(
'20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6'
)
SMP_DEBUG_KEY_PUBLIC_Y= bytes.fromhex(
'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b'
)
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
@@ -188,8 +196,6 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code)
# -----------------------------------------------------------------------------
@@ -201,20 +207,20 @@ class SMP_Command:
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
'''
smp_classes: ClassVar[dict[int, type[SMP_Command]]] = {}
smp_classes: ClassVar[dict[CommandCode, type[SMP_Command]]] = {}
fields: ClassVar[Fields]
code: int = field(default=0, init=False)
code: CommandCode = field(default=CommandCode(0), init=False)
name: str = field(default='', init=False)
_payload: bytes | None = field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
code = pdu[0]
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)
if subclass is None:
instance = SMP_Command()
instance.name = SMP_Command.command_name(code)
instance.name = code.name
instance.code = code
instance.payload = pdu
return instance
@@ -222,59 +228,14 @@ class SMP_Command:
instance.payload = pdu[1:]
return instance
@staticmethod
def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod
def auth_req_str(value: int) -> str:
bonding_flags = value & 3
mitm = (value >> 2) & 1
sc = (value >> 3) & 1
keypress = (value >> 4) & 1
ct2 = (value >> 5) & 1
return (
f'bonding_flags={bonding_flags}, '
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
)
@staticmethod
def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod
def key_distribution_str(value: int) -> str:
key_types: list[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
key_types.append('ID')
if value & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
key_types.append('SIGN')
if value & SMP_LINK_KEY_DISTRIBUTION_FLAG:
key_types.append('LINK')
return ','.join(key_types)
@staticmethod
def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
_Command = TypeVar("_Command", bound="SMP_Command")
@classmethod
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
subclass.name = subclass.__name__.upper()
subclass.code = key_with_value(SMP_COMMAND_NAMES, subclass.name)
if subclass.code is None:
raise KeyError(
f'Command name {subclass.name} not found in SMP_COMMAND_NAMES'
)
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
subclass.name = subclass.__name__.upper()
# Register a factory for this class
SMP_Command.smp_classes[subclass.code] = subclass
return subclass
@property
@@ -308,19 +269,17 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
'''
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
code = CommandCode.PAIRING_REQUEST
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
@@ -332,19 +291,17 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
'''
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
code = CommandCode.PAIRING_RESPONSE
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
@@ -356,6 +313,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
'''
code = CommandCode.PAIRING_CONFIRM
confirm_value: bytes = field(metadata=metadata(16))
@@ -367,6 +326,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
'''
code = CommandCode.PAIRING_RANDOM
random_value: bytes = field(metadata=metadata(16))
@@ -378,7 +339,9 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
'''
reason: int = field(metadata=metadata({'size': 1, 'mapper': error_name}))
code = CommandCode.PAIRING_FAILED
reason: ErrorCode = field(metadata=ErrorCode.type_metadata(1))
# -----------------------------------------------------------------------------
@@ -389,6 +352,8 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
'''
code = CommandCode.PAIRING_PUBLIC_KEY
public_key_x: bytes = field(metadata=metadata(32))
public_key_y: bytes = field(metadata=metadata(32))
@@ -401,6 +366,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
'''
code = CommandCode.PAIRING_DHKEY_CHECK
dhkey_check: bytes = field(metadata=metadata(16))
@@ -412,10 +379,10 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
'''
notification_type: int = field(
metadata=metadata(
{'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}
)
code = CommandCode.PAIRING_KEYPRESS_NOTIFICATION
notification_type: KeypressNotificationType = field(
metadata=KeypressNotificationType.type_metadata(1)
)
@@ -427,6 +394,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
'''
code = CommandCode.ENCRYPTION_INFORMATION
long_term_key: bytes = field(metadata=metadata(16))
@@ -438,6 +407,8 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
'''
code = CommandCode.MASTER_IDENTIFICATION
ediv: int = field(metadata=metadata(2))
rand: bytes = field(metadata=metadata(8))
@@ -450,6 +421,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
'''
code = CommandCode.IDENTITY_INFORMATION
identity_resolving_key: bytes = field(metadata=metadata(16))
@@ -461,6 +434,8 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
'''
code = CommandCode.IDENTITY_ADDRESS_INFORMATION
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
@@ -473,6 +448,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
'''
code = CommandCode.SIGNING_INFORMATION
signature_key: bytes = field(metadata=metadata(16))
@@ -484,25 +461,9 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
'''
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
code = CommandCode.SECURITY_REQUEST
# -----------------------------------------------------------------------------
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0
if bonding:
value |= SMP_BONDING_AUTHREQ
if mitm:
value |= SMP_MITM_AUTHREQ
if sc:
value |= SMP_SC_AUTHREQ
if keypress:
value |= SMP_KEYPRESS_AUTHREQ
if ct2:
value |= SMP_CT2_AUTHREQ
return value
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
# -----------------------------------------------------------------------------
@@ -676,8 +637,8 @@ class Session:
self.ltk_rand = bytes(8)
self.link_key: bytes | None = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.initiator_key_distribution: KeyDistribution = KeyDistribution(0)
self.responder_key_distribution: KeyDistribution = KeyDistribution(0)
self.peer_random_value: bytes | None = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
@@ -728,10 +689,10 @@ class Session:
)
# Key Distribution (default values before negotiation)
self.initiator_key_distribution = (
self.initiator_key_distribution = KeyDistribution(
pairing_config.delegate.local_initiator_key_distribution
)
self.responder_key_distribution = (
self.responder_key_distribution = KeyDistribution(
pairing_config.delegate.local_responder_key_distribution
)
@@ -743,7 +704,7 @@ class Session:
self.ct2: bool = False
# I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability
self.io_capability = IoCapability(pairing_config.delegate.io_capability)
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
@@ -822,8 +783,14 @@ class Session:
return self.nx[0 if self.is_responder else 1]
@property
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def auth_req(self) -> AuthReq:
return AuthReq.from_booleans(
bonding=self.bonding,
sc=self.sc,
mitm=self.mitm,
keypress=self.keypress,
ct2=self.ct2,
)
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
if not self.sc and not self.completed:
@@ -843,7 +810,7 @@ class Session:
if self.connection.transport == PhysicalTransport.BR_EDR:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
if (not self.mitm) and (auth_req & AuthReq.MITM == 0):
self.pairing_method = PairingMethod.JUST_WORKS
return
@@ -861,7 +828,7 @@ class Session:
self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(
self, expected: bytes, received: bytes, error: int
self, expected: bytes, received: bytes, error: ErrorCode
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received:
@@ -881,7 +848,7 @@ class Session:
except Exception:
logger.exception('exception while confirm')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -900,7 +867,7 @@ class Session:
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -911,13 +878,13 @@ class Session:
passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
return
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -972,7 +939,7 @@ class Session:
def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error: int) -> None:
def send_pairing_failed(self, error: ErrorCode) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error)
@@ -1144,7 +1111,7 @@ class Session:
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
@@ -1155,14 +1122,14 @@ class Session:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
and self.initiator_key_distribution & KeyDistribution.ENC_KEY
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.ENC_KEY:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1173,7 +1140,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.ID_KEY:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1183,25 +1150,25 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.SIGN_KEY:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.LINK_KEY:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
else:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
and self.responder_key_distribution & KeyDistribution.ENC_KEY
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.ENC_KEY:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1212,7 +1179,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.ID_KEY:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1222,30 +1189,30 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.SIGN_KEY:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.LINK_KEY:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = []
if not self.sc and self.connection.transport == PhysicalTransport.LE:
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.ENC_KEY != 0:
self.peer_expected_distributions.append(
SMP_Encryption_Information_Command
)
self.peer_expected_distributions.append(
SMP_Master_Identification_Command
)
if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.ID_KEY != 0:
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
self.peer_expected_distributions.append(
SMP_Identity_Address_Information_Command
)
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.SIGN_KEY != 0:
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
logger.debug(
'expecting distributions: '
@@ -1258,7 +1225,7 @@ class Session:
logger.warning(
color('received key distribution on a non-encrypted connection', 'red')
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
return
# Check that this command class is expected
@@ -1278,7 +1245,7 @@ class Session:
'red',
)
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
async def pair(self) -> None:
# Start pairing as an initiator
@@ -1389,34 +1356,56 @@ class Session:
)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
def on_pairing_failure(self, reason: ErrorCode) -> None:
logger.warning('pairing failure (%s)', reason.name)
if self.completed:
return
self.completed = True
error = ProtocolError(reason, 'smp', error_name(reason))
error = ProtocolError(reason, 'smp', reason.name)
if self.pairing_result is not None and not self.pairing_result.done():
self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(command)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
self.send_command(response)
else:
logger.error(color('SMP command not handled???', 'red'))
try:
match command:
case SMP_Pairing_Request_Command():
self.on_smp_pairing_request_command(command)
case SMP_Pairing_Response_Command():
self.on_smp_pairing_response_command(command)
case SMP_Pairing_Confirm_Command():
self.on_smp_pairing_confirm_command(command)
case SMP_Pairing_Random_Command():
self.on_smp_pairing_random_command(command)
case SMP_Pairing_Failed_Command():
self.on_smp_pairing_failed_command(command)
case SMP_Encryption_Information_Command():
self.on_smp_encryption_information_command(command)
case SMP_Master_Identification_Command():
self.on_smp_master_identification_command(command)
case SMP_Identity_Information_Command():
self.on_smp_identity_information_command(command)
case SMP_Identity_Address_Information_Command():
self.on_smp_identity_address_information_command(command)
case SMP_Signing_Information_Command():
self.on_smp_signing_information_command(command)
case SMP_Pairing_Public_Key_Command():
self.on_smp_pairing_public_key_command(command)
case SMP_Pairing_DHKey_Check_Command():
self.on_smp_pairing_dhkey_check_command(command)
# case SMP_Security_Request_Command():
# self.on_smp_security_request_command(command)
# case SMP_Pairing_Keypress_Notification_Command():
# self.on_smp_pairing_keypress_notification_command(command)
case _:
logger.error(color('SMP command not handled', 'red'))
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(reason=ErrorCode.UNSPECIFIED_REASON)
self.send_command(response)
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
@@ -1436,16 +1425,16 @@ class Session:
accepted = False
if not accepted:
logger.debug('pairing rejected by delegate')
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
self.send_pairing_failed(ErrorCode.PAIRING_NOT_SUPPORTED)
return
# Save the request
self.preq = bytes(command)
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
self.ct2 = self.ct2 and (command.auth_req & AuthReq.CT2 != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1456,7 +1445,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1475,8 +1464,11 @@ class Session:
(
self.initiator_key_distribution,
self.responder_key_distribution,
) = await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
) = map(
KeyDistribution,
await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
),
)
self.compute_peer_expected_distributions(self.initiator_key_distribution)
@@ -1514,8 +1506,8 @@ class Session:
self.peer_io_capability = command.io_capability
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1526,7 +1518,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1546,7 +1538,7 @@ class Session:
command.responder_key_distribution & ~self.responder_key_distribution != 0
):
# The response isn't a subset of the request
self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR)
self.send_pairing_failed(ErrorCode.INVALID_PARAMETERS)
return
self.initiator_key_distribution = command.initiator_key_distribution
self.responder_key_distribution = command.responder_key_distribution
@@ -1624,7 +1616,7 @@ class Session:
)
assert self.confirm_value
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1665,7 +1657,7 @@ class Session:
self.pkb, self.pka, command.random_value, bytes([0])
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1678,7 +1670,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1707,7 +1699,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1824,7 +1816,7 @@ class Session:
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
ErrorCode.CONFIRM_VALUE_FAILED,
):
return
@@ -1858,7 +1850,7 @@ class Session:
expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
expected, command.dhkey_check, ErrorCode.DHKEY_CHECK_FAILED
):
return
@@ -1937,6 +1929,7 @@ class Manager(utils.EventEmitter):
self._ecc_key = None
self.pairing_config_factory = pairing_config_factory
self.session_proxy = Session
self.debug_mode = False
def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug(
@@ -1962,7 +1955,7 @@ class Manager(utils.EventEmitter):
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
if command.code == CommandCode.SECURITY_REQUEST:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
@@ -1983,6 +1976,13 @@ class Manager(utils.EventEmitter):
@property
def ecc_key(self) -> crypto.EccKey:
if self.debug_mode:
# Core - Vol 3, Part H:
# When the Security Manager is placed in a Debug mode it shall use the
# following Diffie-Hellman private / public key pair:
debug_key = crypto.EccKey.from_private_key_bytes(SMP_DEBUG_KEY_PRIVATE)
return debug_key
if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
@@ -2002,15 +2002,13 @@ class Manager(utils.EventEmitter):
def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection)
if pairing_config:
auth_req = smp_auth_req(
pairing_config.bonding,
pairing_config.mitm,
pairing_config.sc,
False,
False,
auth_req = AuthReq.from_booleans(
bonding=pairing_config.bonding,
sc=pairing_config.sc,
mitm=pairing_config.mitm,
)
else:
auth_req = 0
auth_req = AuthReq(0)
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session: Session) -> None:
@@ -2026,7 +2024,7 @@ class Manager(utils.EventEmitter):
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: int) -> None:
def on_pairing_failure(self, session: Session, reason: ErrorCode) -> None:
self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session: Session) -> None:

View File

@@ -133,10 +133,10 @@ def on_avrcp_start(
utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed() -> None:
async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex())
async for uid in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", hex(uid))
websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
{"type": "track-changed", "params": {"identifier": hex(uid)}}
)
async def monitor_playback_status() -> None:

View File

@@ -83,6 +83,7 @@ async def main() -> None:
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
server_device.add_service(device_info_service)
await server_device.start_advertising()
# Connect the client to the server
connection = await client_device.connect(server_device.random_address)

View File

@@ -13,13 +13,12 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.10"
dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
"cryptography >= 44.0.3; platform_system=='Emscripten'",
"cryptography >= 39.0.0; platform_system=='Emscripten'",
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
# updated. Relax the version requirement since it's better than being completely unable
@@ -37,7 +36,7 @@ dependencies = [
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
"pyserial >= 3.5; platform_system!='Emscripten'",
"pyusb >= 1.2; platform_system!='Emscripten'",
"tomli ~= 2.2.1; platform_system!='Emscripten'",
"tomli ~= 2.2.1; platform_system!='Emscripten' and python_version<'3.11'",
"websockets >= 15.0.1; platform_system!='Emscripten'",
]

56
rust/Cargo.lock generated
View File

@@ -221,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
[[package]]
name = "bytes"
version = "1.5.0"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]]
name = "cc"
@@ -657,6 +657,18 @@ dependencies = [
"wasi",
]
[[package]]
name = "getrandom"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasip2",
]
[[package]]
name = "gimli"
version = "0.28.0"
@@ -1402,21 +1414,26 @@ dependencies = [
]
[[package]]
name = "rand"
version = "0.8.5"
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
@@ -1424,11 +1441,11 @@ dependencies = [
[[package]]
name = "rand_core"
version = "0.6.4"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
dependencies = [
"getrandom",
"getrandom 0.3.4",
]
[[package]]
@@ -1455,7 +1472,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.10",
"redox_syscall 0.2.16",
"thiserror",
]
@@ -2028,6 +2045,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasip2"
version = "1.0.2+wasi-0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.87"
@@ -2283,3 +2309,9 @@ dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"

View File

@@ -30,7 +30,7 @@ hex = "0.4.3"
itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
bytes = "1.5.0"
bytes = "1.11.1"
pdl-derive = "0.2.0"
pdl-runtime = "0.2.0"
futures = "0.3.28"
@@ -57,7 +57,7 @@ anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
rusb = "0.9.2"
rand = "0.8.5"
rand = "0.9.3"
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio
import struct
from collections.abc import Sequence
from unittest import mock
import pytest
@@ -118,8 +119,6 @@ class TwoDevices(test_utils.TwoDevices):
scope=avrcp.Scope.NOW_PLAYING,
uid=0,
uid_counter=1,
start_item=0,
end_item=0,
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
),
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
@@ -136,7 +135,7 @@ def test_command(command: avrcp.Command):
"event,",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(identifier=b'12356'),
avrcp.TrackChangedEvent(uid=12356),
avrcp.VolumeChangedEvent(volume=9),
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
avrcp.AddressedPlayerChangedEvent(
@@ -581,6 +580,87 @@ async def test_get_supported_company_ids():
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_player_application_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: [
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.SINGLE_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.OFF,
],
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: [
avrcp.ApplicationSetting.ShuffleOnOffStatus.OFF,
avrcp.ApplicationSetting.ShuffleOnOffStatus.ALL_TRACKS_SHUFFLE,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
}
two_devices.protocols[1].delegate = avrcp.Delegate(
supported_player_app_settings=expected_settings
)
actual_settings = await two_devices.protocols[
0
].list_supported_player_app_settings()
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_set_player_app_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.SetPlayerApplicationSettingValueCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
],
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
),
)
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
}
assert delegate.player_app_settings == expected_settings
actual_settings = await two_devices.protocols[0].get_player_app_settings(
[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
)
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_play_item():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
with mock.patch.object(delegate, delegate.play_item.__name__) as play_item_mock:
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.PlayItemCommand(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
),
)
play_item_mock.assert_called_once_with(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_volume():
@@ -635,6 +715,102 @@ async def test_monitor_now_playing_content():
await anext(now_playing_iter)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_track_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.TRACK_CHANGED]
)
delegate.current_track_uid = avrcp.TrackChangedEvent.NO_TRACK
track_iter = two_devices.protocols[0].monitor_track_changed()
# Interim
assert (await anext(track_iter)) == avrcp.TrackChangedEvent.NO_TRACK
# Changed
two_devices.protocols[1].notify_track_changed(1)
assert (await anext(track_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_uid_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.UIDS_CHANGED]
)
delegate.uid_counter = 0
uid_iter = two_devices.protocols[0].monitor_uids()
# Interim
assert (await anext(uid_iter)) == 0
# Changed
two_devices.protocols[1].notify_uids_changed(1)
assert (await anext(uid_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_addressed_player():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.ADDRESSED_PLAYER_CHANGED]
)
delegate.uid_counter = 0
delegate.addressed_player_id = 0
addressed_player_iter = two_devices.protocols[0].monitor_addressed_player()
# Interim
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=0, uid_counter=0)
# Changed
two_devices.protocols[1].notify_addressed_player_changed(
avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
)
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_player_app_settings():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
supported_events=[avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED]
)
delegate.player_app_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
}
settings_iter = two_devices.protocols[0].monitor_player_application_settings()
# Interim
interim = await anext(settings_iter)
assert interim[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert (
interim[0].value_id
== avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
)
# Changed
two_devices.protocols[1].notify_player_application_settings_changed(
[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
)
]
)
changed = await anext(settings_iter)
assert changed[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert changed[0].value_id == avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frame_parser()

View File

@@ -73,6 +73,14 @@ def test_uuid_to_hex_str() -> None:
)
# -----------------------------------------------------------------------------
def test_uuid_hash() -> None:
uuid = UUID("1234")
uuid_128_bytes = UUID.from_bytes(uuid.to_bytes(force_128=True))
assert uuid in {uuid_128_bytes}
assert uuid_128_bytes in {uuid}
# -----------------------------------------------------------------------------
def test_appearance() -> None:
a = Appearance(Appearance.Category.COMPUTER, Appearance.ComputerSubcategory.LAPTOP)

View File

@@ -309,6 +309,27 @@ async def test_legacy_advertising_disconnection(auto_restart):
assert not devices[0].is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_le_multiple_connects():
devices = TwoDevices()
for controller in devices.controllers:
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
for dev in devices:
await dev.power_on()
await devices[0].start_advertising(auto_restart=True, advertising_interval_min=1.0)
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
await async_barrier()
await async_barrier()
# a second connection attempt is working
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_advertising_and_scanning():
@@ -445,7 +466,9 @@ async def test_get_remote_le_features():
devices = TwoDevices()
await devices.setup_connection()
assert (await devices.connections[0].get_remote_le_features()) is not None
assert (
await devices.connections[0].get_remote_le_features()
) == devices.controllers[1].le_features
# -----------------------------------------------------------------------------
@@ -803,6 +826,22 @@ async def test_remote_name_request():
assert actual_name == expected_name
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_classic_features():
devices = TwoDevices()
devices[0].classic_enabled = True
devices[1].classic_enabled = True
await devices[0].power_on()
await devices[1].power_on()
connection = await devices[0].connect_classic(devices[1].public_address)
assert (
await asyncio.wait_for(connection.get_remote_classic_features(), _TIMEOUT)
== devices.controllers[1].lmp_features
)
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()

View File

@@ -22,6 +22,7 @@ import unittest.mock
import pytest
from bumble import controller, hci
from bumble.controller import Controller
from bumble.hci import (
HCI_AclDataPacket,
@@ -49,34 +50,27 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'supported_commands, lmp_features',
'supported_commands, max_lmp_features_page_number',
[
(
# Default commands
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000',
# Only LE LMP feature
'0000000060000000',
),
(controller.Controller.supported_commands, 0),
(
# All commands
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
set(hci.HCI_Command.command_names.keys()),
# 3 pages of LMP features
'000102030405060708090A0B0C0D0E0F011112131415161718191A1B1C1D1E1F',
2,
),
],
)
async def test_reset(supported_commands: str, lmp_features: str):
async def test_reset(supported_commands: set[int], max_lmp_features_page_number: int):
controller = Controller('C')
controller.supported_commands = bytes.fromhex(supported_commands)
controller.lmp_features = bytes.fromhex(lmp_features)
controller.supported_commands = supported_commands
controller.lmp_features_max_page_number = max_lmp_features_page_number
host = Host(controller, AsyncPipeSink(controller))
await host.reset()
assert host.local_lmp_features == int.from_bytes(
bytes.fromhex(lmp_features), 'little'
assert host.local_lmp_features == (
controller.lmp_features & ~(1 << (64 * max_lmp_features_page_number + 1))
)

View File

@@ -21,6 +21,7 @@ import logging
import os
import pathlib
import tempfile
from unittest import mock
import pytest
@@ -179,11 +180,55 @@ async def test_default_namespace(temporary_file):
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_filename(tmp_path):
import platformdirs
with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path):
# Case 1: no namespace, no filename
keystore = JsonKeyStore(None, None)
expected_directory = tmp_path / 'Pairing'
expected_filename = expected_directory / 'keys.json'
assert keystore.directory_name == expected_directory
assert keystore.filename == expected_filename
# Save some data
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
assert expected_filename.exists()
# Load back
keystore2 = JsonKeyStore(None, None)
foo = await keystore2.get('foo')
assert foo is not None
assert foo.ltk.value == ltk
# Case 2: namespace, no filename
keystore3 = JsonKeyStore('my:namespace', None)
# safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-')
expected_filename3 = expected_directory / 'my-namespace.json'
assert keystore3.filename == expected_filename3
# Save some data
await keystore3.update('bar', keys)
assert expected_filename3.exists()
# Load back
keystore4 = JsonKeyStore('my:namespace', None)
bar = await keystore4.get('bar')
assert bar is not None
assert bar.ltk.value == ltk
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
await test_no_filename()
# -----------------------------------------------------------------------------

View File

@@ -29,8 +29,7 @@ from bumble.gatt import Characteristic, Service
from bumble.hci import Role
from bumble.pairing import PairingConfig, PairingDelegate
from bumble.smp import (
SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_PAIRING_NOT_SUPPORTED_ERROR,
ErrorCode,
OobContext,
OobLegacyContext,
)
@@ -378,7 +377,7 @@ async def test_self_smp_reject():
await _test_self_smp_with_configs(None, rejecting_pairing_config)
paired = True
except ProtocolError as error:
assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
assert error.error_code == ErrorCode.PAIRING_NOT_SUPPORTED
assert not paired
@@ -403,7 +402,7 @@ async def test_self_smp_wrong_pin():
)
paired = True
except ProtocolError as error:
assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.error_code == ErrorCode.CONFIRM_VALUE_FAILED
assert not paired
@@ -534,11 +533,11 @@ async def test_self_smp_oob_sc():
with pytest.raises(ProtocolError) as error:
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
with pytest.raises(ProtocolError):
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
# -----------------------------------------------------------------------------

View File

@@ -24,7 +24,7 @@ import pytest
from bumble import crypto, pairing, smp
from bumble.core import AdvertisingData
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.device import Device
from bumble.device import Device, DeviceConfiguration
from bumble.hci import Address
from bumble.pairing import LeRole, OobData, OobSharedData
@@ -312,3 +312,17 @@ async def test_send_identity_address_command(
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
@pytest.mark.asyncio
async def test_smp_debug_mode():
config = DeviceConfiguration(smp_debug_mode=True)
device = Device(config=config)
assert device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
device.smp_manager.debug_mode = False
assert not device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert not device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y

View File

@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" />
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="heart_rate_monitor.js"></script>
<style>

View File

@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="scanner.css">
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="scanner.js"></script>
</style>

View File

@@ -4,7 +4,7 @@
<title>Bumble Speaker</title>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="speaker.css">
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script type="module" src="speaker.js"></script>
<script type="module" src="../ui.js"></script>
</head>