mirror of
https://github.com/google/bumble.git
synced 2026-06-06 08:22:27 +00:00
Compare commits
90 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72d821b1f6 | |||
| afe064b4ea | |||
| 8d0cef70c2 | |||
| 9cefde1c3e | |||
| ffb9d5f117 | |||
| 7d3be8157a | |||
| 9dc9c348e5 | |||
| b18555539e | |||
| 8a853d5b2f | |||
| 8988a85245 | |||
| 0813da2278 | |||
| a1ff183d44 | |||
| 7adf44eddf | |||
| 05accbf805 | |||
| 80f54f2a09 | |||
| 07b5e33e09 | |||
| b874e26a4f | |||
| baa5257780 | |||
| a91ea9110c | |||
| 1686c5b11b | |||
| d9481992bb | |||
| 16d0ed56cf | |||
| c55eb156b8 | |||
| 8614881fb3 | |||
| 27d02ef18d | |||
| c0725e2a4a | |||
| bf0784dde4 | |||
| 444f43f6a3 | |||
| 2420c47cf1 | |||
| 0a78e7506b | |||
| f7cc6f6657 | |||
| f2824ee6b8 | |||
| 7188ef08de | |||
| 3ded9014d3 | |||
| b6125bdfb1 | |||
| dc17f4f1ca | |||
| 3f65380c20 | |||
| 25a0056ecc | |||
| 85f6b10983 | |||
| e85f041e9d | |||
| ee09e6f10d | |||
| c3daf4a7e1 | |||
| 3af623be7e | |||
| 4e76d3057b | |||
| eda7360222 | |||
| a4c15c00de | |||
| cba4df4aef | |||
| ceb8b448e9 | |||
| 311b716d5c | |||
| 0ba9e5c317 | |||
| 3517225b62 | |||
| ad4bb1578b | |||
| 4af65b381b | |||
| a5cd3365ae | |||
| 2915cb8bb6 | |||
| 28e485b7b3 | |||
| 1198f2c3f5 | |||
| 80aaf6a2b9 | |||
| eb64debb62 | |||
| c158f25b1e | |||
| 1330e83517 | |||
| d9c9bea6cb | |||
| 3b937631b3 | |||
| f8aa309111 | |||
| 673281ed71 | |||
| 3ac7af4683 | |||
| 5ebfaae74e | |||
| e6175f85fe | |||
| f9ba527508 | |||
| a407c4cabf | |||
| 6c2d6dddb5 | |||
| 797cd216d4 | |||
| e2e8c90e47 | |||
| 3d5648cdc3 | |||
| d810d93aaf | |||
| 81d9adb983 | |||
| 377fa896f7 | |||
| 79e5974946 | |||
| 657451474e | |||
| 9f730dce6f | |||
| 1a6be95a7e | |||
| aea5320d71 | |||
| 91cb1b1df3 | |||
| 81bdc86e52 | |||
| f23cad34e3 | |||
| 30fde2c00b | |||
| 256a1a7405 | |||
| 116d9b26bb | |||
| aabe2ca063 | |||
| 2d17a5f742 |
@@ -69,7 +69,7 @@ jobs:
|
|||||||
components: clippy,rustfmt
|
components: clippy,rustfmt
|
||||||
toolchain: ${{ matrix.rust-version }}
|
toolchain: ${{ matrix.rust-version }}
|
||||||
- name: Install Rust dependencies
|
- name: Install Rust dependencies
|
||||||
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
|
run: cargo install cargo-all-features --version 1.11.0 --locked # allows building/testing combinations of features
|
||||||
- name: Check License Headers
|
- name: Check License Headers
|
||||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||||
- name: Rust Build
|
- name: Rust Build
|
||||||
|
|||||||
+7
-2
@@ -24,13 +24,18 @@ import dataclasses
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
import sys
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
)
|
)
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomli
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
import tomllib
|
||||||
|
else:
|
||||||
|
import tomli as tomllib
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import lc3 # type: ignore # pylint: disable=E0401
|
import lc3 # type: ignore # pylint: disable=E0401
|
||||||
@@ -114,7 +119,7 @@ def parse_broadcast_list(filename: str) -> Sequence[Broadcast]:
|
|||||||
broadcasts: list[Broadcast] = []
|
broadcasts: list[Broadcast] = []
|
||||||
|
|
||||||
with open(filename, "rb") as config_file:
|
with open(filename, "rb") as config_file:
|
||||||
config = tomli.load(config_file)
|
config = tomllib.load(config_file)
|
||||||
for broadcast in config.get("broadcasts", []):
|
for broadcast in config.get("broadcasts", []):
|
||||||
sources = []
|
sources = []
|
||||||
for source in broadcast.get("sources", []):
|
for source in broadcast.get("sources", []):
|
||||||
|
|||||||
+8
-7
@@ -20,11 +20,12 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from prompt_toolkit.shortcuts import PromptSession
|
from prompt_toolkit.shortcuts import PromptSession
|
||||||
|
|
||||||
from bumble import data_types
|
from bumble import data_types, smp
|
||||||
from bumble.a2dp import make_audio_sink_service_sdp_records
|
from bumble.a2dp import make_audio_sink_service_sdp_records
|
||||||
from bumble.att import (
|
from bumble.att import (
|
||||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
||||||
@@ -40,7 +41,7 @@ from bumble.core import (
|
|||||||
PhysicalTransport,
|
PhysicalTransport,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
)
|
)
|
||||||
from bumble.device import Device, Peer
|
from bumble.device import Connection, Device, Peer
|
||||||
from bumble.gatt import (
|
from bumble.gatt import (
|
||||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||||
GATT_GENERIC_ACCESS_SERVICE,
|
GATT_GENERIC_ACCESS_SERVICE,
|
||||||
@@ -53,7 +54,6 @@ from bumble.hci import OwnAddressType
|
|||||||
from bumble.keys import JsonKeyStore
|
from bumble.keys import JsonKeyStore
|
||||||
from bumble.pairing import OobData, PairingConfig, PairingDelegate
|
from bumble.pairing import OobData, PairingConfig, PairingDelegate
|
||||||
from bumble.smp import OobContext, OobLegacyContext
|
from bumble.smp import OobContext, OobLegacyContext
|
||||||
from bumble.smp import error_name as smp_error_name
|
|
||||||
from bumble.transport import open_transport
|
from bumble.transport import open_transport
|
||||||
from bumble.utils import AsyncRunner
|
from bumble.utils import AsyncRunner
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ POST_PAIRING_DELAY = 1
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Waiter:
|
class Waiter:
|
||||||
instance: Waiter | None = None
|
instance: ClassVar[Waiter | None] = None
|
||||||
|
|
||||||
def __init__(self, linger=False):
|
def __init__(self, linger=False):
|
||||||
self.done = asyncio.get_running_loop().create_future()
|
self.done = asyncio.get_running_loop().create_future()
|
||||||
@@ -319,12 +319,13 @@ async def on_classic_pairing(connection):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@AsyncRunner.run_in_task()
|
@AsyncRunner.run_in_task()
|
||||||
async def on_pairing_failure(connection, reason):
|
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode):
|
||||||
print(color('***-----------------------------------', 'red'))
|
print(color('***-----------------------------------', 'red'))
|
||||||
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
|
print(color(f'*** Pairing failed: {reason.name}', 'red'))
|
||||||
print(color('***-----------------------------------', 'red'))
|
print(color('***-----------------------------------', 'red'))
|
||||||
await connection.disconnect()
|
await connection.disconnect()
|
||||||
Waiter.instance.terminate()
|
if Waiter.instance:
|
||||||
|
Waiter.instance.terminate()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
+20
-38
@@ -88,13 +88,6 @@ SBC_DUAL_CHANNEL_MODE = 0x01
|
|||||||
SBC_STEREO_CHANNEL_MODE = 0x02
|
SBC_STEREO_CHANNEL_MODE = 0x02
|
||||||
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
|
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
|
||||||
|
|
||||||
SBC_CHANNEL_MODE_NAMES = {
|
|
||||||
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
|
|
||||||
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
|
|
||||||
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
|
|
||||||
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
|
|
||||||
}
|
|
||||||
|
|
||||||
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
|
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
|
||||||
|
|
||||||
SBC_SUBBANDS = [4, 8]
|
SBC_SUBBANDS = [4, 8]
|
||||||
@@ -102,11 +95,6 @@ SBC_SUBBANDS = [4, 8]
|
|||||||
SBC_SNR_ALLOCATION_METHOD = 0x00
|
SBC_SNR_ALLOCATION_METHOD = 0x00
|
||||||
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
|
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
|
||||||
|
|
||||||
SBC_ALLOCATION_METHOD_NAMES = {
|
|
||||||
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
|
|
||||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
|
||||||
}
|
|
||||||
|
|
||||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||||
|
|
||||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||||
@@ -129,13 +117,6 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
|
|||||||
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
|
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
|
||||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
|
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
|
||||||
|
|
||||||
MPEG_2_4_OBJECT_TYPE_NAMES = {
|
|
||||||
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
|
|
||||||
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
|
|
||||||
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
|
|
||||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||||
|
|
||||||
@@ -267,26 +248,27 @@ class MediaCodecInformation:
|
|||||||
def create(
|
def create(
|
||||||
cls, media_codec_type: int, data: bytes
|
cls, media_codec_type: int, data: bytes
|
||||||
) -> MediaCodecInformation | bytes:
|
) -> MediaCodecInformation | bytes:
|
||||||
if media_codec_type == CodecType.SBC:
|
match media_codec_type:
|
||||||
return SbcMediaCodecInformation.from_bytes(data)
|
case CodecType.SBC:
|
||||||
elif media_codec_type == CodecType.MPEG_2_4_AAC:
|
return SbcMediaCodecInformation.from_bytes(data)
|
||||||
return AacMediaCodecInformation.from_bytes(data)
|
case CodecType.MPEG_2_4_AAC:
|
||||||
elif media_codec_type == CodecType.NON_A2DP:
|
return AacMediaCodecInformation.from_bytes(data)
|
||||||
vendor_media_codec_information = (
|
case CodecType.NON_A2DP:
|
||||||
VendorSpecificMediaCodecInformation.from_bytes(data)
|
vendor_media_codec_information = (
|
||||||
)
|
VendorSpecificMediaCodecInformation.from_bytes(data)
|
||||||
if (
|
|
||||||
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
|
||||||
vendor_media_codec_information.vendor_id
|
|
||||||
)
|
|
||||||
) and (
|
|
||||||
media_codec_information_class := vendor_class_map.get(
|
|
||||||
vendor_media_codec_information.codec_id
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return media_codec_information_class.from_bytes(
|
|
||||||
vendor_media_codec_information.value
|
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
||||||
|
vendor_media_codec_information.vendor_id
|
||||||
|
)
|
||||||
|
) and (
|
||||||
|
media_codec_information_class := vendor_class_map.get(
|
||||||
|
vendor_media_codec_information.codec_id
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return media_codec_information_class.from_bytes(
|
||||||
|
vendor_media_codec_information.value
|
||||||
|
)
|
||||||
return vendor_media_codec_information
|
return vendor_media_codec_information
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
+32
-30
@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
|
|||||||
are ignored [..], unless they are embedded in numeric or string constants"
|
are ignored [..], unless they are embedded in numeric or string constants"
|
||||||
Raises AtParsingError in case of invalid input string."""
|
Raises AtParsingError in case of invalid input string."""
|
||||||
|
|
||||||
tokens = []
|
tokens: list[bytearray] = []
|
||||||
in_quotes = False
|
in_quotes = False
|
||||||
token = bytearray()
|
token = bytearray()
|
||||||
for b in buffer:
|
for b in buffer:
|
||||||
@@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
|
|||||||
tokens.append(token[1:-1])
|
tokens.append(token[1:-1])
|
||||||
token = bytearray()
|
token = bytearray()
|
||||||
else:
|
else:
|
||||||
if char == b' ':
|
match char:
|
||||||
pass
|
case b' ':
|
||||||
elif char == b',' or char == b')':
|
pass
|
||||||
tokens.append(token)
|
case b',' | b')':
|
||||||
tokens.append(char)
|
tokens.append(token)
|
||||||
token = bytearray()
|
tokens.append(char)
|
||||||
elif char == b'(':
|
token = bytearray()
|
||||||
if len(token) > 0:
|
case b'(':
|
||||||
raise AtParsingError("open_paren following regular character")
|
if len(token) > 0:
|
||||||
tokens.append(char)
|
raise AtParsingError("open_paren following regular character")
|
||||||
elif char == b'"':
|
tokens.append(char)
|
||||||
if len(token) > 0:
|
case b'"':
|
||||||
raise AtParsingError("quote following regular character")
|
if len(token) > 0:
|
||||||
in_quotes = True
|
raise AtParsingError("quote following regular character")
|
||||||
token.extend(char)
|
in_quotes = True
|
||||||
else:
|
token.extend(char)
|
||||||
token.extend(char)
|
case _:
|
||||||
|
token.extend(char)
|
||||||
|
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
return [bytes(token) for token in tokens if len(token) > 0]
|
return [bytes(token) for token in tokens if len(token) > 0]
|
||||||
@@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
|
|||||||
current: bytes | list = b''
|
current: bytes | list = b''
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token == b',':
|
match token:
|
||||||
accumulator[-1].append(current)
|
case b',':
|
||||||
current = b''
|
accumulator[-1].append(current)
|
||||||
elif token == b'(':
|
current = b''
|
||||||
accumulator.append([])
|
case b'(':
|
||||||
elif token == b')':
|
accumulator.append([])
|
||||||
if len(accumulator) < 2:
|
case b')':
|
||||||
raise AtParsingError("close_paren without matching open_paren")
|
if len(accumulator) < 2:
|
||||||
accumulator[-1].append(current)
|
raise AtParsingError("close_paren without matching open_paren")
|
||||||
current = accumulator.pop()
|
accumulator[-1].append(current)
|
||||||
else:
|
current = accumulator.pop()
|
||||||
current = token
|
case _:
|
||||||
|
current = token
|
||||||
|
|
||||||
accumulator[-1].append(current)
|
accumulator[-1].append(current)
|
||||||
if len(accumulator) > 1:
|
if len(accumulator) > 1:
|
||||||
|
|||||||
+57
-52
@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
|
|||||||
|
|
||||||
from bumble import hci, l2cap, utils
|
from bumble import hci, l2cap, utils
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.core import UUID, InvalidOperationError, ProtocolError
|
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
|
||||||
from bumble.hci import HCI_Object
|
from bumble.hci import HCI_Object
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -249,6 +249,8 @@ class ATT_PDU:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
|
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
|
||||||
|
if not pdu:
|
||||||
|
raise InvalidPacketError("Empty ATT PDU")
|
||||||
op_code = pdu[0]
|
op_code = pdu[0]
|
||||||
|
|
||||||
subclass = ATT_PDU.pdu_classes.get(op_code)
|
subclass = ATT_PDU.pdu_classes.get(op_code)
|
||||||
@@ -954,12 +956,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
|||||||
self.permissions = permissions
|
self.permissions = permissions
|
||||||
|
|
||||||
# Convert the type to a UUID object if it isn't already
|
# Convert the type to a UUID object if it isn't already
|
||||||
if isinstance(attribute_type, str):
|
match attribute_type:
|
||||||
self.type = UUID(attribute_type)
|
case str():
|
||||||
elif isinstance(attribute_type, bytes):
|
self.type = UUID(attribute_type)
|
||||||
self.type = UUID.from_bytes(attribute_type)
|
case bytes():
|
||||||
else:
|
self.type = UUID.from_bytes(attribute_type)
|
||||||
self.type = attribute_type
|
case _:
|
||||||
|
self.type = attribute_type
|
||||||
|
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@@ -994,30 +997,31 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
value: _T | None
|
value: _T | None
|
||||||
if isinstance(self.value, AttributeValue):
|
match self.value:
|
||||||
try:
|
case AttributeValue():
|
||||||
read_value = self.value.read(connection)
|
try:
|
||||||
if inspect.isawaitable(read_value):
|
read_value = self.value.read(connection)
|
||||||
value = await read_value
|
if inspect.isawaitable(read_value):
|
||||||
else:
|
value = await read_value
|
||||||
value = read_value
|
else:
|
||||||
except ATT_Error as error:
|
value = read_value
|
||||||
raise ATT_Error(
|
except ATT_Error as error:
|
||||||
error_code=error.error_code, att_handle=self.handle
|
raise ATT_Error(
|
||||||
) from error
|
error_code=error.error_code, att_handle=self.handle
|
||||||
elif isinstance(self.value, AttributeValueV2):
|
) from error
|
||||||
try:
|
case AttributeValueV2():
|
||||||
read_value = self.value.read(bearer)
|
try:
|
||||||
if inspect.isawaitable(read_value):
|
read_value = self.value.read(bearer)
|
||||||
value = await read_value
|
if inspect.isawaitable(read_value):
|
||||||
else:
|
value = await read_value
|
||||||
value = read_value
|
else:
|
||||||
except ATT_Error as error:
|
value = read_value
|
||||||
raise ATT_Error(
|
except ATT_Error as error:
|
||||||
error_code=error.error_code, att_handle=self.handle
|
raise ATT_Error(
|
||||||
) from error
|
error_code=error.error_code, att_handle=self.handle
|
||||||
else:
|
) from error
|
||||||
value = self.value
|
case _:
|
||||||
|
value = self.value
|
||||||
|
|
||||||
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
|
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
|
||||||
|
|
||||||
@@ -1049,26 +1053,27 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
|||||||
|
|
||||||
decoded_value = self.decode_value(value)
|
decoded_value = self.decode_value(value)
|
||||||
|
|
||||||
if isinstance(self.value, AttributeValue):
|
match self.value:
|
||||||
try:
|
case AttributeValue():
|
||||||
result = self.value.write(connection, decoded_value)
|
try:
|
||||||
if inspect.isawaitable(result):
|
result = self.value.write(connection, decoded_value)
|
||||||
await result
|
if inspect.isawaitable(result):
|
||||||
except ATT_Error as error:
|
await result
|
||||||
raise ATT_Error(
|
except ATT_Error as error:
|
||||||
error_code=error.error_code, att_handle=self.handle
|
raise ATT_Error(
|
||||||
) from error
|
error_code=error.error_code, att_handle=self.handle
|
||||||
elif isinstance(self.value, AttributeValueV2):
|
) from error
|
||||||
try:
|
case AttributeValueV2():
|
||||||
result = self.value.write(bearer, decoded_value)
|
try:
|
||||||
if inspect.isawaitable(result):
|
result = self.value.write(bearer, decoded_value)
|
||||||
await result
|
if inspect.isawaitable(result):
|
||||||
except ATT_Error as error:
|
await result
|
||||||
raise ATT_Error(
|
except ATT_Error as error:
|
||||||
error_code=error.error_code, att_handle=self.handle
|
raise ATT_Error(
|
||||||
) from error
|
error_code=error.error_code, att_handle=self.handle
|
||||||
else:
|
) from error
|
||||||
self.value = decoded_value
|
case _:
|
||||||
|
self.value = decoded_value
|
||||||
|
|
||||||
self.emit(self.EVENT_WRITE, connection, decoded_value)
|
self.emit(self.EVENT_WRITE, connection, decoded_value)
|
||||||
|
|
||||||
@@ -1078,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
|||||||
else:
|
else:
|
||||||
value_str = str(self.value)
|
value_str = str(self.value)
|
||||||
if value_str:
|
if value_str:
|
||||||
value_string = f', value={self.value.hex()}'
|
value_string = f', value={value_str}'
|
||||||
else:
|
else:
|
||||||
value_string = ''
|
value_string = ''
|
||||||
return (
|
return (
|
||||||
|
|||||||
+140
-77
@@ -17,6 +17,7 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
@@ -311,6 +312,13 @@ class MessageAssembler:
|
|||||||
def on_pdu(self, pdu: bytes) -> None:
|
def on_pdu(self, pdu: bytes) -> None:
|
||||||
self.packet_count += 1
|
self.packet_count += 1
|
||||||
|
|
||||||
|
# Drop empty PDUs sent by remote — accessing pdu[0] below would
|
||||||
|
# raise IndexError, propagating up to the L2CAP read loop and
|
||||||
|
# tearing down the channel. Same class as #912 (ATT empty PDU).
|
||||||
|
if not pdu:
|
||||||
|
logger.warning('AVDTP message assembler: empty PDU dropped')
|
||||||
|
return
|
||||||
|
|
||||||
transaction_label = pdu[0] >> 4
|
transaction_label = pdu[0] >> 4
|
||||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||||
message_type = Message.MessageType(pdu[0] & 3)
|
message_type = Message.MessageType(pdu[0] & 3)
|
||||||
@@ -324,6 +332,23 @@ class MessageAssembler:
|
|||||||
Protocol.PacketType.SINGLE_PACKET,
|
Protocol.PacketType.SINGLE_PACKET,
|
||||||
Protocol.PacketType.START_PACKET,
|
Protocol.PacketType.START_PACKET,
|
||||||
):
|
):
|
||||||
|
# Both single and start packets carry the signal identifier in
|
||||||
|
# pdu[1]; start packets additionally carry the packet count in
|
||||||
|
# pdu[2]. Guard each access so a malformed remote frame can't
|
||||||
|
# crash the message assembler.
|
||||||
|
if len(pdu) < 2:
|
||||||
|
logger.warning(
|
||||||
|
'AVDTP %s packet too short (%d bytes); dropped',
|
||||||
|
packet_type.name,
|
||||||
|
len(pdu),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
|
||||||
|
logger.warning(
|
||||||
|
'AVDTP START packet missing signal-packet count; dropped'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if self.message is not None:
|
if self.message is not None:
|
||||||
# The previous message has not been terminated
|
# The previous message has not been terminated
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1453,8 +1478,23 @@ class Protocol(utils.EventEmitter):
|
|||||||
handler = getattr(self, handler_name, None)
|
handler = getattr(self, handler_name, None)
|
||||||
if handler:
|
if handler:
|
||||||
try:
|
try:
|
||||||
response = handler(message)
|
result = handler(message)
|
||||||
self.send_message(transaction_label, response)
|
if asyncio.iscoroutine(result):
|
||||||
|
|
||||||
|
async def wait_and_send() -> None:
|
||||||
|
try:
|
||||||
|
response = await result
|
||||||
|
if response:
|
||||||
|
self.send_message(transaction_label, response)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
color("!!! Exception in handler:", "red")
|
||||||
|
)
|
||||||
|
|
||||||
|
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
|
||||||
|
else:
|
||||||
|
if result:
|
||||||
|
self.send_message(transaction_label, result)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(color("!!! Exception in handler:", "red"))
|
logger.exception(color("!!! Exception in handler:", "red"))
|
||||||
else:
|
else:
|
||||||
@@ -1535,7 +1575,7 @@ class Protocol(utils.EventEmitter):
|
|||||||
async def send_command(self, command: Message):
|
async def send_command(self, command: Message):
|
||||||
# TODO: support timeouts
|
# TODO: support timeouts
|
||||||
# Send the command
|
# Send the command
|
||||||
(transaction_label, transaction_result) = await self.start_transaction()
|
transaction_label, transaction_result = await self.start_transaction()
|
||||||
self.send_message(transaction_label, command)
|
self.send_message(transaction_label, command)
|
||||||
|
|
||||||
# Wait for the response
|
# Wait for the response
|
||||||
@@ -1600,14 +1640,14 @@ class Protocol(utils.EventEmitter):
|
|||||||
async def abort(self, seid: int) -> Abort_Response:
|
async def abort(self, seid: int) -> Abort_Response:
|
||||||
return await self.send_command(Abort_Command(seid))
|
return await self.send_command(Abort_Command(seid))
|
||||||
|
|
||||||
def on_discover_command(self, command: Discover_Command) -> Message | None:
|
async def on_discover_command(self, command: Discover_Command) -> Message | None:
|
||||||
endpoint_infos = [
|
endpoint_infos = [
|
||||||
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
|
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
|
||||||
for endpoint in self.local_endpoints
|
for endpoint in self.local_endpoints
|
||||||
]
|
]
|
||||||
return Discover_Response(endpoint_infos)
|
return Discover_Response(endpoint_infos)
|
||||||
|
|
||||||
def on_get_capabilities_command(
|
async def on_get_capabilities_command(
|
||||||
self, command: Get_Capabilities_Command
|
self, command: Get_Capabilities_Command
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
@@ -1616,7 +1656,7 @@ class Protocol(utils.EventEmitter):
|
|||||||
|
|
||||||
return Get_Capabilities_Response(endpoint.capabilities)
|
return Get_Capabilities_Response(endpoint.capabilities)
|
||||||
|
|
||||||
def on_get_all_capabilities_command(
|
async def on_get_all_capabilities_command(
|
||||||
self, command: Get_All_Capabilities_Command
|
self, command: Get_All_Capabilities_Command
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
@@ -1625,7 +1665,7 @@ class Protocol(utils.EventEmitter):
|
|||||||
|
|
||||||
return Get_All_Capabilities_Response(endpoint.capabilities)
|
return Get_All_Capabilities_Response(endpoint.capabilities)
|
||||||
|
|
||||||
def on_set_configuration_command(
|
async def on_set_configuration_command(
|
||||||
self, command: Set_Configuration_Command
|
self, command: Set_Configuration_Command
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
@@ -1640,10 +1680,10 @@ class Protocol(utils.EventEmitter):
|
|||||||
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
|
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
|
||||||
self.streams[command.acp_seid] = stream
|
self.streams[command.acp_seid] = stream
|
||||||
|
|
||||||
result = stream.on_set_configuration_command(command.capabilities)
|
result = await stream.on_set_configuration_command(command.capabilities)
|
||||||
return result or Set_Configuration_Response()
|
return result or Set_Configuration_Response()
|
||||||
|
|
||||||
def on_get_configuration_command(
|
async def on_get_configuration_command(
|
||||||
self, command: Get_Configuration_Command
|
self, command: Get_Configuration_Command
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
@@ -1652,29 +1692,31 @@ class Protocol(utils.EventEmitter):
|
|||||||
if endpoint.stream is None:
|
if endpoint.stream is None:
|
||||||
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
|
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
return endpoint.stream.on_get_configuration_command()
|
return await endpoint.stream.on_get_configuration_command()
|
||||||
|
|
||||||
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
|
async def on_reconfigure_command(
|
||||||
|
self, command: Reconfigure_Command
|
||||||
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
|
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
|
||||||
if endpoint.stream is None:
|
if endpoint.stream is None:
|
||||||
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = endpoint.stream.on_reconfigure_command(command.capabilities)
|
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
|
||||||
return result or Reconfigure_Response()
|
return result or Reconfigure_Response()
|
||||||
|
|
||||||
def on_open_command(self, command: Open_Command) -> Message | None:
|
async def on_open_command(self, command: Open_Command) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||||
if endpoint.stream is None:
|
if endpoint.stream is None:
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = endpoint.stream.on_open_command()
|
result = await endpoint.stream.on_open_command()
|
||||||
return result or Open_Response()
|
return result or Open_Response()
|
||||||
|
|
||||||
def on_start_command(self, command: Start_Command) -> Message | None:
|
async def on_start_command(self, command: Start_Command) -> Message | None:
|
||||||
for seid in command.acp_seids:
|
for seid in command.acp_seids:
|
||||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
@@ -1688,12 +1730,12 @@ class Protocol(utils.EventEmitter):
|
|||||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||||
if not endpoint or not endpoint.stream:
|
if not endpoint or not endpoint.stream:
|
||||||
raise InvalidStateError("Should already be checked!")
|
raise InvalidStateError("Should already be checked!")
|
||||||
if (result := endpoint.stream.on_start_command()) is not None:
|
if (result := await endpoint.stream.on_start_command()) is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return Start_Response()
|
return Start_Response()
|
||||||
|
|
||||||
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
|
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
|
||||||
for seid in command.acp_seids:
|
for seid in command.acp_seids:
|
||||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
@@ -1707,45 +1749,47 @@ class Protocol(utils.EventEmitter):
|
|||||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||||
if not endpoint or not endpoint.stream:
|
if not endpoint or not endpoint.stream:
|
||||||
raise InvalidStateError("Should already be checked!")
|
raise InvalidStateError("Should already be checked!")
|
||||||
if (result := endpoint.stream.on_suspend_command()) is not None:
|
if (result := await endpoint.stream.on_suspend_command()) is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return Suspend_Response()
|
return Suspend_Response()
|
||||||
|
|
||||||
def on_close_command(self, command: Close_Command) -> Message | None:
|
async def on_close_command(self, command: Close_Command) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||||
if endpoint.stream is None:
|
if endpoint.stream is None:
|
||||||
return Close_Reject(AVDTP_BAD_STATE_ERROR)
|
return Close_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = endpoint.stream.on_close_command()
|
result = await endpoint.stream.on_close_command()
|
||||||
return result or Close_Response()
|
return result or Close_Response()
|
||||||
|
|
||||||
def on_abort_command(self, command: Abort_Command) -> Message | None:
|
async def on_abort_command(self, command: Abort_Command) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None or endpoint.stream is None:
|
if endpoint is None or endpoint.stream is None:
|
||||||
return Abort_Response()
|
return Abort_Response()
|
||||||
|
|
||||||
endpoint.stream.on_abort_command()
|
await endpoint.stream.on_abort_command()
|
||||||
return Abort_Response()
|
return Abort_Response()
|
||||||
|
|
||||||
def on_security_control_command(
|
async def on_security_control_command(
|
||||||
self, command: Security_Control_Command
|
self, command: Security_Control_Command
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||||
|
|
||||||
result = endpoint.on_security_control_command(command.data)
|
result = await endpoint.on_security_control_command(command.data)
|
||||||
return result or Security_Control_Response()
|
return result or Security_Control_Response()
|
||||||
|
|
||||||
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
|
async def on_delayreport_command(
|
||||||
|
self, command: DelayReport_Command
|
||||||
|
) -> Message | None:
|
||||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||||
if endpoint is None:
|
if endpoint is None:
|
||||||
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||||
|
|
||||||
result = endpoint.on_delayreport_command(command.delay)
|
result = await endpoint.on_delayreport_command(command.delay)
|
||||||
return result or DelayReport_Response()
|
return result or DelayReport_Response()
|
||||||
|
|
||||||
|
|
||||||
@@ -1903,25 +1947,22 @@ class Stream:
|
|||||||
await self.rtp_channel.disconnect()
|
await self.rtp_channel.disconnect()
|
||||||
self.rtp_channel = None
|
self.rtp_channel = None
|
||||||
|
|
||||||
# Release the endpoint
|
|
||||||
self.local_endpoint.in_use = 0
|
|
||||||
|
|
||||||
self.change_state(State.IDLE)
|
self.change_state(State.IDLE)
|
||||||
|
|
||||||
def on_set_configuration_command(
|
async def on_set_configuration_command(
|
||||||
self, configuration: Iterable[ServiceCapabilities]
|
self, configuration: Iterable[ServiceCapabilities]
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
if self.state != State.IDLE:
|
if self.state != State.IDLE:
|
||||||
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_set_configuration_command(configuration)
|
result = await self.local_endpoint.on_set_configuration_command(configuration)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.change_state(State.CONFIGURED)
|
self.change_state(State.CONFIGURED)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_get_configuration_command(self) -> Message | None:
|
async def on_get_configuration_command(self) -> Message | None:
|
||||||
if self.state not in (
|
if self.state not in (
|
||||||
State.CONFIGURED,
|
State.CONFIGURED,
|
||||||
State.OPEN,
|
State.OPEN,
|
||||||
@@ -1929,25 +1970,25 @@ class Stream:
|
|||||||
):
|
):
|
||||||
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
return self.local_endpoint.on_get_configuration_command()
|
return await self.local_endpoint.on_get_configuration_command()
|
||||||
|
|
||||||
def on_reconfigure_command(
|
async def on_reconfigure_command(
|
||||||
self, configuration: Iterable[ServiceCapabilities]
|
self, configuration: Iterable[ServiceCapabilities]
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
if self.state != State.OPEN:
|
if self.state != State.OPEN:
|
||||||
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_reconfigure_command(configuration)
|
result = await self.local_endpoint.on_reconfigure_command(configuration)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_open_command(self) -> Message | None:
|
async def on_open_command(self) -> Message | None:
|
||||||
if self.state != State.CONFIGURED:
|
if self.state != State.CONFIGURED:
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_open_command()
|
result = await self.local_endpoint.on_open_command()
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1957,7 +1998,7 @@ class Stream:
|
|||||||
self.change_state(State.OPEN)
|
self.change_state(State.OPEN)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_start_command(self) -> Message | None:
|
async def on_start_command(self) -> Message | None:
|
||||||
if self.state != State.OPEN:
|
if self.state != State.OPEN:
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
@@ -1966,29 +2007,29 @@ class Stream:
|
|||||||
logger.warning('received start command before RTP channel establishment')
|
logger.warning('received start command before RTP channel establishment')
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_start_command()
|
result = await self.local_endpoint.on_start_command()
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.change_state(State.STREAMING)
|
self.change_state(State.STREAMING)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_suspend_command(self) -> Message | None:
|
async def on_suspend_command(self) -> Message | None:
|
||||||
if self.state != State.STREAMING:
|
if self.state != State.STREAMING:
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_suspend_command()
|
result = await self.local_endpoint.on_suspend_command()
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.change_state(State.OPEN)
|
self.change_state(State.OPEN)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_close_command(self) -> Message | None:
|
async def on_close_command(self) -> Message | None:
|
||||||
if self.state not in (State.OPEN, State.STREAMING):
|
if self.state not in (State.OPEN, State.STREAMING):
|
||||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||||
|
|
||||||
result = self.local_endpoint.on_close_command()
|
result = await self.local_endpoint.on_close_command()
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -2003,7 +2044,8 @@ class Stream:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_abort_command(self) -> Message | None:
|
async def on_abort_command(self) -> Message | None:
|
||||||
|
await self.local_endpoint.on_abort_command()
|
||||||
if self.rtp_channel is None:
|
if self.rtp_channel is None:
|
||||||
# No need to wait
|
# No need to wait
|
||||||
self.change_state(State.IDLE)
|
self.change_state(State.IDLE)
|
||||||
@@ -2028,7 +2070,6 @@ class Stream:
|
|||||||
def on_l2cap_channel_close(self) -> None:
|
def on_l2cap_channel_close(self) -> None:
|
||||||
logger.debug(color('<<< stream channel closed', 'magenta'))
|
logger.debug(color('<<< stream channel closed', 'magenta'))
|
||||||
self.local_endpoint.on_rtp_channel_close()
|
self.local_endpoint.on_rtp_channel_close()
|
||||||
self.local_endpoint.in_use = 0
|
|
||||||
self.rtp_channel = None
|
self.rtp_channel = None
|
||||||
|
|
||||||
if self.state in (State.CLOSING, State.ABORTING):
|
if self.state in (State.CLOSING, State.ABORTING):
|
||||||
@@ -2053,7 +2094,6 @@ class Stream:
|
|||||||
self.state = State.IDLE
|
self.state = State.IDLE
|
||||||
|
|
||||||
local_endpoint.stream = self
|
local_endpoint.stream = self
|
||||||
local_endpoint.in_use = 1
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
@@ -2063,14 +2103,16 @@ class Stream:
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@dataclass
|
class StreamEndPoint(abc.ABC):
|
||||||
class StreamEndPoint:
|
|
||||||
seid: int
|
seid: int
|
||||||
media_type: MediaType
|
media_type: MediaType
|
||||||
tsep: StreamEndPointType
|
tsep: StreamEndPointType
|
||||||
in_use: int
|
|
||||||
capabilities: Iterable[ServiceCapabilities]
|
capabilities: Iterable[ServiceCapabilities]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_use(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class StreamEndPointProxy:
|
class StreamEndPointProxy:
|
||||||
@@ -2110,14 +2152,30 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
|
|||||||
in_use: int,
|
in_use: int,
|
||||||
capabilities: Iterable[ServiceCapabilities],
|
capabilities: Iterable[ServiceCapabilities],
|
||||||
) -> None:
|
) -> None:
|
||||||
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
|
# StreamEndPoint attributes
|
||||||
StreamEndPointProxy.__init__(self, protocol, seid)
|
self.seid = seid
|
||||||
|
self.media_type = media_type
|
||||||
|
self.tsep = tsep
|
||||||
|
self._in_use = in_use
|
||||||
|
self.capabilities = capabilities
|
||||||
|
|
||||||
|
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_use(self) -> int:
|
||||||
|
return self._in_use
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||||
stream: Stream | None
|
stream: Stream | None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_use(self) -> int:
|
||||||
|
if self.stream and self.stream.state != State.IDLE:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
EVENT_CONFIGURATION = "configuration"
|
EVENT_CONFIGURATION = "configuration"
|
||||||
EVENT_OPEN = "open"
|
EVENT_OPEN = "open"
|
||||||
EVENT_START = "start"
|
EVENT_START = "start"
|
||||||
@@ -2140,8 +2198,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
|||||||
capabilities: Iterable[ServiceCapabilities],
|
capabilities: Iterable[ServiceCapabilities],
|
||||||
configuration: Iterable[ServiceCapabilities] | None = None,
|
configuration: Iterable[ServiceCapabilities] | None = None,
|
||||||
):
|
):
|
||||||
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
|
|
||||||
utils.EventEmitter.__init__(self)
|
utils.EventEmitter.__init__(self)
|
||||||
|
# StreamEndPoint attributes
|
||||||
|
self.seid = seid
|
||||||
|
self.media_type = media_type
|
||||||
|
self.tsep = tsep
|
||||||
|
self.capabilities = capabilities
|
||||||
|
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.configuration = configuration if configuration is not None else []
|
self.configuration = configuration if configuration is not None else []
|
||||||
self.stream = None
|
self.stream = None
|
||||||
@@ -2155,13 +2218,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
|||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""[Source Only] Handles when receiving close command."""
|
"""[Source Only] Handles when receiving close command."""
|
||||||
|
|
||||||
def on_reconfigure_command(
|
async def on_reconfigure_command(
|
||||||
self, command: Iterable[ServiceCapabilities]
|
self, command: Iterable[ServiceCapabilities]
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
del command # unused.
|
del command # unused.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_set_configuration_command(
|
async def on_set_configuration_command(
|
||||||
self, configuration: Iterable[ServiceCapabilities]
|
self, configuration: Iterable[ServiceCapabilities]
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -2172,34 +2235,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
|||||||
self.emit(self.EVENT_CONFIGURATION)
|
self.emit(self.EVENT_CONFIGURATION)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_get_configuration_command(self) -> Message | None:
|
async def on_get_configuration_command(self) -> Message | None:
|
||||||
return Get_Configuration_Response(self.configuration)
|
return Get_Configuration_Response(self.configuration)
|
||||||
|
|
||||||
def on_open_command(self) -> Message | None:
|
async def on_open_command(self) -> Message | None:
|
||||||
self.emit(self.EVENT_OPEN)
|
self.emit(self.EVENT_OPEN)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_start_command(self) -> Message | None:
|
async def on_start_command(self) -> Message | None:
|
||||||
self.emit(self.EVENT_START)
|
self.emit(self.EVENT_START)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_suspend_command(self) -> Message | None:
|
async def on_suspend_command(self) -> Message | None:
|
||||||
self.emit(self.EVENT_SUSPEND)
|
self.emit(self.EVENT_SUSPEND)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_close_command(self) -> Message | None:
|
async def on_close_command(self) -> Message | None:
|
||||||
self.emit(self.EVENT_CLOSE)
|
self.emit(self.EVENT_CLOSE)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_abort_command(self) -> Message | None:
|
async def on_abort_command(self) -> Message | None:
|
||||||
self.emit(self.EVENT_ABORT)
|
self.emit(self.EVENT_ABORT)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_delayreport_command(self, delay: int) -> Message | None:
|
async def on_delayreport_command(self, delay: int) -> Message | None:
|
||||||
self.emit(self.EVENT_DELAY_REPORT, delay)
|
self.emit(self.EVENT_DELAY_REPORT, delay)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_security_control_command(self, data: bytes) -> Message | None:
|
async def on_security_control_command(self, data: bytes) -> Message | None:
|
||||||
self.emit(self.EVENT_SECURITY_CONTROL, data)
|
self.emit(self.EVENT_SECURITY_CONTROL, data)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -2227,12 +2290,12 @@ class LocalSource(LocalStreamEndPoint):
|
|||||||
codec_capabilities,
|
codec_capabilities,
|
||||||
] + list(other_capabilities)
|
] + list(other_capabilities)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
protocol,
|
protocol=protocol,
|
||||||
seid,
|
seid=seid,
|
||||||
codec_capabilities.media_type,
|
media_type=codec_capabilities.media_type,
|
||||||
AVDTP_TSEP_SRC,
|
tsep=AVDTP_TSEP_SRC,
|
||||||
capabilities,
|
capabilities=capabilities,
|
||||||
capabilities,
|
configuration=capabilities,
|
||||||
)
|
)
|
||||||
self.packet_pump = packet_pump
|
self.packet_pump = packet_pump
|
||||||
|
|
||||||
@@ -2251,13 +2314,13 @@ class LocalSource(LocalStreamEndPoint):
|
|||||||
self.emit(self.EVENT_STOP)
|
self.emit(self.EVENT_STOP)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_start_command(self) -> Message | None:
|
async def on_start_command(self) -> Message | None:
|
||||||
asyncio.create_task(self.start())
|
await self.start()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_suspend_command(self) -> Message | None:
|
async def on_suspend_command(self) -> Message | None:
|
||||||
asyncio.create_task(self.stop())
|
await self.stop()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -2271,11 +2334,11 @@ class LocalSink(LocalStreamEndPoint):
|
|||||||
codec_capabilities,
|
codec_capabilities,
|
||||||
]
|
]
|
||||||
super().__init__(
|
super().__init__(
|
||||||
protocol,
|
protocol=protocol,
|
||||||
seid,
|
seid=seid,
|
||||||
codec_capabilities.media_type,
|
media_type=codec_capabilities.media_type,
|
||||||
AVDTP_TSEP_SNK,
|
tsep=AVDTP_TSEP_SNK,
|
||||||
capabilities,
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_rtp_channel_open(self) -> None:
|
def on_rtp_channel_open(self) -> None:
|
||||||
|
|||||||
+344
-87
@@ -22,7 +22,14 @@ import enum
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
from collections.abc import (
|
||||||
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import ClassVar, SupportsBytes, TypeVar
|
from typing import ClassVar, SupportsBytes, TypeVar
|
||||||
|
|
||||||
@@ -1049,11 +1056,9 @@ class GetItemAttributesCommand(Command):
|
|||||||
scope: Scope = field(metadata=Scope.type_metadata(1))
|
scope: Scope = field(metadata=Scope.type_metadata(1))
|
||||||
uid: int = field(metadata=_UINT64_BE_METADATA)
|
uid: int = field(metadata=_UINT64_BE_METADATA)
|
||||||
uid_counter: int = field(metadata=hci.metadata('>2'))
|
uid_counter: int = field(metadata=hci.metadata('>2'))
|
||||||
start_item: int = field(metadata=hci.metadata('>4'))
|
|
||||||
end_item: int = field(metadata=hci.metadata('>4'))
|
|
||||||
# When attributes is empty, all attributes will be requested.
|
# When attributes is empty, all attributes will be requested.
|
||||||
attributes: Sequence[MediaAttributeId] = field(
|
attributes: Sequence[MediaAttributeId] = field(
|
||||||
metadata=MediaAttributeId.type_metadata(1, list_begin=True, list_end=True)
|
metadata=MediaAttributeId.type_metadata(4, list_begin=True, list_end=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1512,7 +1517,9 @@ class PlaybackPositionChangedEvent(Event):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TrackChangedEvent(Event):
|
class TrackChangedEvent(Event):
|
||||||
event_id = EventId.TRACK_CHANGED
|
event_id = EventId.TRACK_CHANGED
|
||||||
identifier: bytes = field(metadata=hci.metadata('*'))
|
NO_TRACK = 0xFFFFFFFFFFFFFFFF
|
||||||
|
|
||||||
|
uid: int = field(metadata=_UINT64_BE_METADATA)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -1536,16 +1543,19 @@ class PlayerApplicationSettingChangedEvent(Event):
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
|
match self.attribute_id:
|
||||||
self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id)
|
case ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
|
||||||
elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
|
self.value_id = ApplicationSetting.EqualizerOnOffStatus(
|
||||||
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
|
self.value_id
|
||||||
elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
|
)
|
||||||
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
|
case ApplicationSetting.AttributeId.REPEAT_MODE:
|
||||||
elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
|
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
|
||||||
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
|
case ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
|
||||||
else:
|
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
|
||||||
self.value_id = ApplicationSetting.GenericValue(self.value_id)
|
case ApplicationSetting.AttributeId.SCAN_ON_OFF:
|
||||||
|
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
|
||||||
|
case _:
|
||||||
|
self.value_id = ApplicationSetting.GenericValue(self.value_id)
|
||||||
|
|
||||||
player_application_settings: Sequence[Setting] = field(
|
player_application_settings: Sequence[Setting] = field(
|
||||||
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
|
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
|
||||||
@@ -1619,6 +1629,8 @@ class Delegate:
|
|||||||
|
|
||||||
supported_events: list[EventId]
|
supported_events: list[EventId]
|
||||||
supported_company_ids: list[int]
|
supported_company_ids: list[int]
|
||||||
|
supported_player_app_settings: dict[ApplicationSetting.AttributeId, list[int]]
|
||||||
|
player_app_settings: dict[ApplicationSetting.AttributeId, int]
|
||||||
volume: int
|
volume: int
|
||||||
playback_status: PlayStatus
|
playback_status: PlayStatus
|
||||||
|
|
||||||
@@ -1626,11 +1638,23 @@ class Delegate:
|
|||||||
self,
|
self,
|
||||||
supported_events: Iterable[EventId] = (),
|
supported_events: Iterable[EventId] = (),
|
||||||
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
|
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
|
||||||
|
supported_player_app_settings: (
|
||||||
|
Mapping[ApplicationSetting.AttributeId, Sequence[int]] | None
|
||||||
|
) = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.supported_company_ids = list(supported_company_ids)
|
self.supported_company_ids = list(supported_company_ids)
|
||||||
self.supported_events = list(supported_events)
|
self.supported_events = list(supported_events)
|
||||||
self.volume = 0
|
self.volume = 0
|
||||||
self.playback_status = PlayStatus.STOPPED
|
self.playback_status = PlayStatus.STOPPED
|
||||||
|
self.supported_player_app_settings = (
|
||||||
|
{key: list(value) for key, value in supported_player_app_settings.items()}
|
||||||
|
if supported_player_app_settings
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
self.player_app_settings = {}
|
||||||
|
self.uid_counter = 0
|
||||||
|
self.addressed_player_id = 0
|
||||||
|
self.current_track_uid = TrackChangedEvent.NO_TRACK
|
||||||
|
|
||||||
async def get_supported_events(self) -> list[EventId]:
|
async def get_supported_events(self) -> list[EventId]:
|
||||||
return self.supported_events
|
return self.supported_events
|
||||||
@@ -1663,6 +1687,38 @@ class Delegate:
|
|||||||
async def get_playback_status(self) -> PlayStatus:
|
async def get_playback_status(self) -> PlayStatus:
|
||||||
return self.playback_status
|
return self.playback_status
|
||||||
|
|
||||||
|
async def get_supported_player_app_settings(
|
||||||
|
self,
|
||||||
|
) -> dict[ApplicationSetting.AttributeId, list[int]]:
|
||||||
|
return self.supported_player_app_settings
|
||||||
|
|
||||||
|
async def get_current_player_app_settings(
|
||||||
|
self,
|
||||||
|
) -> dict[ApplicationSetting.AttributeId, int]:
|
||||||
|
return self.player_app_settings
|
||||||
|
|
||||||
|
async def set_player_app_settings(
|
||||||
|
self, attribute: ApplicationSetting.AttributeId, value: int
|
||||||
|
) -> None:
|
||||||
|
self.player_app_settings[attribute] = value
|
||||||
|
|
||||||
|
async def play_item(self, scope: Scope, uid: int, uid_counter: int) -> None:
|
||||||
|
logger.debug(
|
||||||
|
"@@@ play_item: scope=%s, uid=%s, uid_counter=%s",
|
||||||
|
scope,
|
||||||
|
uid,
|
||||||
|
uid_counter,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_uid_counter(self) -> int:
|
||||||
|
return self.uid_counter
|
||||||
|
|
||||||
|
async def get_addressed_player_id(self) -> int:
|
||||||
|
return self.addressed_player_id
|
||||||
|
|
||||||
|
async def get_current_track_uid(self) -> int:
|
||||||
|
return self.current_track_uid
|
||||||
|
|
||||||
# TODO add other delegate methods
|
# TODO add other delegate methods
|
||||||
|
|
||||||
|
|
||||||
@@ -1910,6 +1966,51 @@ class Protocol(utils.EventEmitter):
|
|||||||
response = self._check_response(response_context, GetElementAttributesResponse)
|
response = self._check_response(response_context, GetElementAttributesResponse)
|
||||||
return list(response.attributes)
|
return list(response.attributes)
|
||||||
|
|
||||||
|
async def list_supported_player_app_settings(
|
||||||
|
self, attribute_ids: Sequence[ApplicationSetting.AttributeId] = ()
|
||||||
|
) -> dict[ApplicationSetting.AttributeId, list[int]]:
|
||||||
|
"""Get element attributes from the connected peer."""
|
||||||
|
response_context = await self.send_avrcp_command(
|
||||||
|
avc.CommandFrame.CommandType.STATUS,
|
||||||
|
ListPlayerApplicationSettingAttributesCommand(),
|
||||||
|
)
|
||||||
|
if not attribute_ids:
|
||||||
|
list_attribute_response = self._check_response(
|
||||||
|
response_context, ListPlayerApplicationSettingAttributesResponse
|
||||||
|
)
|
||||||
|
attribute_ids = list_attribute_response.attribute
|
||||||
|
|
||||||
|
supported_settings: dict[ApplicationSetting.AttributeId, list[int]] = {}
|
||||||
|
for attribute_id in attribute_ids:
|
||||||
|
response_context = await self.send_avrcp_command(
|
||||||
|
avc.CommandFrame.CommandType.STATUS,
|
||||||
|
ListPlayerApplicationSettingValuesCommand(attribute_id),
|
||||||
|
)
|
||||||
|
list_value_response = self._check_response(
|
||||||
|
response_context, ListPlayerApplicationSettingValuesResponse
|
||||||
|
)
|
||||||
|
supported_settings[attribute_id] = list(list_value_response.value)
|
||||||
|
|
||||||
|
return supported_settings
|
||||||
|
|
||||||
|
async def get_player_app_settings(
|
||||||
|
self, attribute_ids: Sequence[ApplicationSetting.AttributeId]
|
||||||
|
) -> dict[ApplicationSetting.AttributeId, int]:
|
||||||
|
"""Get element attributes from the connected peer."""
|
||||||
|
response_context = await self.send_avrcp_command(
|
||||||
|
avc.CommandFrame.CommandType.STATUS,
|
||||||
|
GetCurrentPlayerApplicationSettingValueCommand(attribute_ids),
|
||||||
|
)
|
||||||
|
response: GetCurrentPlayerApplicationSettingValueResponse = (
|
||||||
|
self._check_response(
|
||||||
|
response_context, GetCurrentPlayerApplicationSettingValueResponse
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
attribute_id: value
|
||||||
|
for attribute_id, value in zip(response.attribute, response.value)
|
||||||
|
}
|
||||||
|
|
||||||
async def monitor_events(
|
async def monitor_events(
|
||||||
self, event_id: EventId, playback_interval: int = 0
|
self, event_id: EventId, playback_interval: int = 0
|
||||||
) -> AsyncIterator[Event]:
|
) -> AsyncIterator[Event]:
|
||||||
@@ -1961,13 +2062,13 @@ class Protocol(utils.EventEmitter):
|
|||||||
|
|
||||||
async def monitor_track_changed(
|
async def monitor_track_changed(
|
||||||
self,
|
self,
|
||||||
) -> AsyncIterator[bytes]:
|
) -> AsyncIterator[int]:
|
||||||
"""Monitor Track changes from the connected peer."""
|
"""Monitor Track changes from the connected peer."""
|
||||||
async for event in self.monitor_events(EventId.TRACK_CHANGED, 0):
|
async for event in self.monitor_events(EventId.TRACK_CHANGED, 0):
|
||||||
if not isinstance(event, TrackChangedEvent):
|
if not isinstance(event, TrackChangedEvent):
|
||||||
logger.warning("unexpected event class")
|
logger.warning("unexpected event class")
|
||||||
continue
|
continue
|
||||||
yield event.identifier
|
yield event.uid
|
||||||
|
|
||||||
async def monitor_playback_position(
|
async def monitor_playback_position(
|
||||||
self, playback_interval: int
|
self, playback_interval: int
|
||||||
@@ -2060,11 +2161,9 @@ class Protocol(utils.EventEmitter):
|
|||||||
"""Notify the connected peer of a Playback Status change."""
|
"""Notify the connected peer of a Playback Status change."""
|
||||||
self.notify_event(PlaybackStatusChangedEvent(status))
|
self.notify_event(PlaybackStatusChangedEvent(status))
|
||||||
|
|
||||||
def notify_track_changed(self, identifier: bytes) -> None:
|
def notify_track_changed(self, uid: int) -> None:
|
||||||
"""Notify the connected peer of a Track change."""
|
"""Notify the connected peer of a Track change."""
|
||||||
if len(identifier) != 8:
|
self.notify_event(TrackChangedEvent(uid))
|
||||||
raise core.InvalidArgumentError("identifier must be 8 bytes")
|
|
||||||
self.notify_event(TrackChangedEvent(identifier))
|
|
||||||
|
|
||||||
def notify_playback_position_changed(self, position: int) -> None:
|
def notify_playback_position_changed(self, position: int) -> None:
|
||||||
"""Notify the connected peer of a Position change."""
|
"""Notify the connected peer of a Position change."""
|
||||||
@@ -2280,21 +2379,40 @@ class Protocol(utils.EventEmitter):
|
|||||||
):
|
):
|
||||||
# TODO: catch exceptions from delegates
|
# TODO: catch exceptions from delegates
|
||||||
command = Command.from_bytes(pdu_id, pdu)
|
command = Command.from_bytes(pdu_id, pdu)
|
||||||
if isinstance(command, GetCapabilitiesCommand):
|
match command:
|
||||||
self._on_get_capabilities_command(transaction_label, command)
|
case GetCapabilitiesCommand():
|
||||||
elif isinstance(command, SetAbsoluteVolumeCommand):
|
self._on_get_capabilities_command(transaction_label, command)
|
||||||
self._on_set_absolute_volume_command(transaction_label, command)
|
case SetAbsoluteVolumeCommand():
|
||||||
elif isinstance(command, RegisterNotificationCommand):
|
self._on_set_absolute_volume_command(transaction_label, command)
|
||||||
self._on_register_notification_command(transaction_label, command)
|
case RegisterNotificationCommand():
|
||||||
elif isinstance(command, GetPlayStatusCommand):
|
self._on_register_notification_command(transaction_label, command)
|
||||||
self._on_get_play_status_command(transaction_label, command)
|
case GetPlayStatusCommand():
|
||||||
else:
|
self._on_get_play_status_command(transaction_label, command)
|
||||||
# Not supported.
|
case ListPlayerApplicationSettingAttributesCommand():
|
||||||
# TODO: check that this is the right way to respond in this case.
|
self._on_list_player_application_setting_attributes_command(
|
||||||
logger.debug("unsupported PDU ID")
|
transaction_label, command
|
||||||
self.send_rejected_avrcp_response(
|
)
|
||||||
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
|
case ListPlayerApplicationSettingValuesCommand():
|
||||||
)
|
self._on_list_player_application_setting_values_command(
|
||||||
|
transaction_label, command
|
||||||
|
)
|
||||||
|
case SetPlayerApplicationSettingValueCommand():
|
||||||
|
self._on_set_player_application_setting_value_command(
|
||||||
|
transaction_label, command
|
||||||
|
)
|
||||||
|
case GetCurrentPlayerApplicationSettingValueCommand():
|
||||||
|
self._on_get_current_player_application_setting_value_command(
|
||||||
|
transaction_label, command
|
||||||
|
)
|
||||||
|
case PlayItemCommand():
|
||||||
|
self._on_play_item_command(transaction_label, command)
|
||||||
|
case _:
|
||||||
|
# Not supported.
|
||||||
|
# TODO: check that this is the right way to respond in this case.
|
||||||
|
logger.debug("unsupported PDU ID")
|
||||||
|
self.send_rejected_avrcp_response(
|
||||||
|
transaction_label, pdu_id, StatusCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("unsupported command type")
|
logger.debug("unsupported command type")
|
||||||
self.send_rejected_avrcp_response(
|
self.send_rejected_avrcp_response(
|
||||||
@@ -2322,26 +2440,29 @@ class Protocol(utils.EventEmitter):
|
|||||||
# is Ok, but if/when more responses are supported, a lookup mechanism would be
|
# is Ok, but if/when more responses are supported, a lookup mechanism would be
|
||||||
# more appropriate.
|
# more appropriate.
|
||||||
response: Response | None = None
|
response: Response | None = None
|
||||||
if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
|
match response_code:
|
||||||
response = RejectedResponse(pdu_id=pdu_id, status_code=StatusCode(pdu[0]))
|
case avc.ResponseFrame.ResponseCode.REJECTED:
|
||||||
elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
|
response = RejectedResponse(
|
||||||
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
|
pdu_id=pdu_id, status_code=StatusCode(pdu[0])
|
||||||
elif response_code in (
|
)
|
||||||
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
case avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
|
||||||
avc.ResponseFrame.ResponseCode.INTERIM,
|
response = NotImplementedResponse(pdu_id=pdu_id, parameters=pdu)
|
||||||
avc.ResponseFrame.ResponseCode.CHANGED,
|
case (
|
||||||
avc.ResponseFrame.ResponseCode.ACCEPTED,
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE
|
||||||
):
|
| avc.ResponseFrame.ResponseCode.INTERIM
|
||||||
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
|
| avc.ResponseFrame.ResponseCode.CHANGED
|
||||||
else:
|
| avc.ResponseFrame.ResponseCode.ACCEPTED
|
||||||
logger.debug("unexpected response code")
|
):
|
||||||
pending_command.response.set_exception(
|
response = Response.from_bytes(pdu=pdu, pdu_id=PduId(pdu_id))
|
||||||
core.ProtocolError(
|
case _:
|
||||||
error_code=None,
|
logger.debug("unexpected response code")
|
||||||
error_namespace="avrcp",
|
pending_command.response.set_exception(
|
||||||
details="unexpected response code",
|
core.ProtocolError(
|
||||||
|
error_code=None,
|
||||||
|
error_namespace="avrcp",
|
||||||
|
details="unexpected response code",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
self.recycle_pending_command(pending_command)
|
self.recycle_pending_command(pending_command)
|
||||||
@@ -2512,22 +2633,18 @@ class Protocol(utils.EventEmitter):
|
|||||||
|
|
||||||
async def get_supported_events() -> None:
|
async def get_supported_events() -> None:
|
||||||
capabilities: Sequence[bytes | SupportsBytes]
|
capabilities: Sequence[bytes | SupportsBytes]
|
||||||
if (
|
match command.capability_id:
|
||||||
command.capability_id
|
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
|
||||||
== GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
|
capabilities = await self.delegate.get_supported_events()
|
||||||
):
|
case GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED.COMPANY_ID:
|
||||||
capabilities = await self.delegate.get_supported_events()
|
company_ids = await self.delegate.get_supported_company_ids()
|
||||||
elif (
|
capabilities = [
|
||||||
command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID
|
company_id.to_bytes(3, 'big') for company_id in company_ids
|
||||||
):
|
]
|
||||||
company_ids = await self.delegate.get_supported_company_ids()
|
case _:
|
||||||
capabilities = [
|
raise core.InvalidArgumentError(
|
||||||
company_id.to_bytes(3, 'big') for company_id in company_ids
|
f"Unsupported capability: {command.capability_id}"
|
||||||
]
|
)
|
||||||
else:
|
|
||||||
raise core.InvalidArgumentError(
|
|
||||||
f"Unsupported capability: {command.capability_id}"
|
|
||||||
)
|
|
||||||
self.send_avrcp_response(
|
self.send_avrcp_response(
|
||||||
transaction_label,
|
transaction_label,
|
||||||
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
@@ -2572,6 +2689,121 @@ class Protocol(utils.EventEmitter):
|
|||||||
|
|
||||||
self._delegate_command(transaction_label, command, get_playback_status())
|
self._delegate_command(transaction_label, command, get_playback_status())
|
||||||
|
|
||||||
|
def _on_list_player_application_setting_attributes_command(
|
||||||
|
self,
|
||||||
|
transaction_label: int,
|
||||||
|
command: ListPlayerApplicationSettingAttributesCommand,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("<<< AVRCP command PDU: %s", command)
|
||||||
|
|
||||||
|
async def get_supported_player_app_settings() -> None:
|
||||||
|
supported_settings = await self.delegate.get_supported_player_app_settings()
|
||||||
|
self.send_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
|
ListPlayerApplicationSettingAttributesResponse(
|
||||||
|
list(supported_settings.keys())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._delegate_command(
|
||||||
|
transaction_label, command, get_supported_player_app_settings()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_list_player_application_setting_values_command(
|
||||||
|
self,
|
||||||
|
transaction_label: int,
|
||||||
|
command: ListPlayerApplicationSettingValuesCommand,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("<<< AVRCP command PDU: %s", command)
|
||||||
|
|
||||||
|
async def get_supported_player_app_settings() -> None:
|
||||||
|
supported_settings = await self.delegate.get_supported_player_app_settings()
|
||||||
|
self.send_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
|
ListPlayerApplicationSettingValuesResponse(
|
||||||
|
supported_settings.get(command.attribute, [])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._delegate_command(
|
||||||
|
transaction_label, command, get_supported_player_app_settings()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_get_current_player_application_setting_value_command(
|
||||||
|
self,
|
||||||
|
transaction_label: int,
|
||||||
|
command: GetCurrentPlayerApplicationSettingValueCommand,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("<<< AVRCP command PDU: %s", command)
|
||||||
|
|
||||||
|
async def get_supported_player_app_settings() -> None:
|
||||||
|
current_settings = await self.delegate.get_current_player_app_settings()
|
||||||
|
|
||||||
|
if not all(
|
||||||
|
attribute in current_settings for attribute in command.attribute
|
||||||
|
):
|
||||||
|
self.send_not_implemented_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
PduId.GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.send_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
|
GetCurrentPlayerApplicationSettingValueResponse(
|
||||||
|
attribute=command.attribute,
|
||||||
|
value=[
|
||||||
|
current_settings[attribute] for attribute in command.attribute
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._delegate_command(
|
||||||
|
transaction_label, command, get_supported_player_app_settings()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_set_player_application_setting_value_command(
|
||||||
|
self,
|
||||||
|
transaction_label: int,
|
||||||
|
command: SetPlayerApplicationSettingValueCommand,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("<<< AVRCP command PDU: %s", command)
|
||||||
|
|
||||||
|
async def set_player_app_settings() -> None:
|
||||||
|
for attribute, value in zip(command.attribute, command.value):
|
||||||
|
await self.delegate.set_player_app_settings(attribute, value)
|
||||||
|
|
||||||
|
self.send_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
|
SetPlayerApplicationSettingValueResponse(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._delegate_command(transaction_label, command, set_player_app_settings())
|
||||||
|
|
||||||
|
def _on_play_item_command(
|
||||||
|
self,
|
||||||
|
transaction_label: int,
|
||||||
|
command: PlayItemCommand,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("<<< AVRCP command PDU: %s", command)
|
||||||
|
|
||||||
|
async def play_item() -> None:
|
||||||
|
await self.delegate.play_item(
|
||||||
|
scope=command.scope, uid=command.uid, uid_counter=command.uid_counter
|
||||||
|
)
|
||||||
|
|
||||||
|
self.send_avrcp_response(
|
||||||
|
transaction_label,
|
||||||
|
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
|
||||||
|
PlayItemResponse(status=StatusCode.OPERATION_COMPLETED),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._delegate_command(transaction_label, command, play_item())
|
||||||
|
|
||||||
def _on_register_notification_command(
|
def _on_register_notification_command(
|
||||||
self, transaction_label: int, command: RegisterNotificationCommand
|
self, transaction_label: int, command: RegisterNotificationCommand
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -2587,26 +2819,51 @@ class Protocol(utils.EventEmitter):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
response: Response
|
event: Event
|
||||||
if command.event_id == EventId.VOLUME_CHANGED:
|
match command.event_id:
|
||||||
volume = await self.delegate.get_absolute_volume()
|
case EventId.VOLUME_CHANGED:
|
||||||
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
|
volume = await self.delegate.get_absolute_volume()
|
||||||
elif command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
|
event = VolumeChangedEvent(volume)
|
||||||
playback_status = await self.delegate.get_playback_status()
|
case EventId.PLAYBACK_STATUS_CHANGED:
|
||||||
response = RegisterNotificationResponse(
|
playback_status = await self.delegate.get_playback_status()
|
||||||
PlaybackStatusChangedEvent(play_status=playback_status)
|
event = PlaybackStatusChangedEvent(play_status=playback_status)
|
||||||
)
|
case EventId.NOW_PLAYING_CONTENT_CHANGED:
|
||||||
elif command.event_id == EventId.NOW_PLAYING_CONTENT_CHANGED:
|
event = NowPlayingContentChangedEvent()
|
||||||
playback_status = await self.delegate.get_playback_status()
|
case EventId.PLAYER_APPLICATION_SETTING_CHANGED:
|
||||||
response = RegisterNotificationResponse(NowPlayingContentChangedEvent())
|
settings = await self.delegate.get_current_player_app_settings()
|
||||||
else:
|
event = PlayerApplicationSettingChangedEvent(
|
||||||
logger.warning("Event supported but not handled %s", command.event_id)
|
[
|
||||||
return
|
PlayerApplicationSettingChangedEvent.Setting(
|
||||||
|
attribute, value # type: ignore
|
||||||
|
)
|
||||||
|
for attribute, value in settings.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
case EventId.AVAILABLE_PLAYERS_CHANGED:
|
||||||
|
event = AvailablePlayersChangedEvent()
|
||||||
|
case EventId.ADDRESSED_PLAYER_CHANGED:
|
||||||
|
event = AddressedPlayerChangedEvent(
|
||||||
|
AddressedPlayerChangedEvent.Player(
|
||||||
|
player_id=await self.delegate.get_addressed_player_id(),
|
||||||
|
uid_counter=await self.delegate.get_uid_counter(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case EventId.UIDS_CHANGED:
|
||||||
|
event = UidsChangedEvent(await self.delegate.get_uid_counter())
|
||||||
|
case EventId.TRACK_CHANGED:
|
||||||
|
event = TrackChangedEvent(
|
||||||
|
await self.delegate.get_current_track_uid()
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
logger.warning(
|
||||||
|
"Event supported but not handled %s", command.event_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self.send_avrcp_response(
|
self.send_avrcp_response(
|
||||||
transaction_label,
|
transaction_label,
|
||||||
avc.ResponseFrame.ResponseCode.INTERIM,
|
avc.ResponseFrame.ResponseCode.INTERIM,
|
||||||
response,
|
RegisterNotificationResponse(event),
|
||||||
)
|
)
|
||||||
self._register_notification_listener(transaction_label, command)
|
self._register_notification_listener(transaction_label, command)
|
||||||
|
|
||||||
|
|||||||
+717
-570
File diff suppressed because it is too large
Load Diff
+82
-74
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
|
import functools
|
||||||
import struct
|
import struct
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -273,6 +274,18 @@ class UUID:
|
|||||||
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
|
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
|
||||||
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
|
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def uuid_128_bytes(self) -> bytes:
|
||||||
|
match len(self.uuid_bytes):
|
||||||
|
case 2:
|
||||||
|
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
|
||||||
|
case 4:
|
||||||
|
return self.BASE_UUID + self.uuid_bytes
|
||||||
|
case 16:
|
||||||
|
return self.uuid_bytes
|
||||||
|
case _:
|
||||||
|
assert False, "unreachable"
|
||||||
|
|
||||||
def to_bytes(self, force_128: bool = False) -> bytes:
|
def to_bytes(self, force_128: bool = False) -> bytes:
|
||||||
'''
|
'''
|
||||||
Serialize UUID in little-endian byte-order
|
Serialize UUID in little-endian byte-order
|
||||||
@@ -280,14 +293,7 @@ class UUID:
|
|||||||
if not force_128:
|
if not force_128:
|
||||||
return self.uuid_bytes
|
return self.uuid_bytes
|
||||||
|
|
||||||
if len(self.uuid_bytes) == 2:
|
return self.uuid_128_bytes
|
||||||
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
|
|
||||||
elif len(self.uuid_bytes) == 4:
|
|
||||||
return self.BASE_UUID + self.uuid_bytes
|
|
||||||
elif len(self.uuid_bytes) == 16:
|
|
||||||
return self.uuid_bytes
|
|
||||||
else:
|
|
||||||
assert False, "unreachable"
|
|
||||||
|
|
||||||
def to_pdu_bytes(self) -> bytes:
|
def to_pdu_bytes(self) -> bytes:
|
||||||
'''
|
'''
|
||||||
@@ -317,7 +323,7 @@ class UUID:
|
|||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, UUID):
|
if isinstance(other, UUID):
|
||||||
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
|
return self.uuid_128_bytes == other.uuid_128_bytes
|
||||||
|
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
return UUID(other) == self
|
return UUID(other) == self
|
||||||
@@ -325,7 +331,7 @@ class UUID:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self.uuid_bytes)
|
return hash(self.uuid_128_bytes)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
result = self.to_hex_str(separator='-')
|
result = self.to_hex_str(separator='-')
|
||||||
@@ -1769,66 +1775,71 @@ class AdvertisingData:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
|
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
|
||||||
if ad_type == AdvertisingData.FLAGS:
|
match ad_type:
|
||||||
ad_type_str = 'Flags'
|
case AdvertisingData.FLAGS:
|
||||||
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
|
ad_type_str = 'Flags'
|
||||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
|
||||||
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
|
case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||||
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
|
case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||||
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
|
case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||||
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
|
case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||||
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
|
case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||||
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
|
case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
|
||||||
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
|
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||||
ad_type_str = 'Service Data'
|
case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
|
||||||
uuid = UUID.from_bytes(ad_data[:2])
|
ad_type_str = 'Service Data'
|
||||||
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
|
uuid = UUID.from_bytes(ad_data[:2])
|
||||||
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
|
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
|
||||||
ad_type_str = 'Service Data'
|
case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
|
||||||
uuid = UUID.from_bytes(ad_data[:4])
|
ad_type_str = 'Service Data'
|
||||||
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
|
uuid = UUID.from_bytes(ad_data[:4])
|
||||||
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
|
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
|
||||||
ad_type_str = 'Service Data'
|
case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
|
||||||
uuid = UUID.from_bytes(ad_data[:16])
|
ad_type_str = 'Service Data'
|
||||||
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
|
uuid = UUID.from_bytes(ad_data[:16])
|
||||||
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME:
|
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
|
||||||
ad_type_str = 'Shortened Local Name'
|
case AdvertisingData.SHORTENED_LOCAL_NAME:
|
||||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
ad_type_str = 'Shortened Local Name'
|
||||||
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
|
|
||||||
ad_type_str = 'Complete Local Name'
|
|
||||||
try:
|
|
||||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||||
except UnicodeDecodeError:
|
case AdvertisingData.COMPLETE_LOCAL_NAME:
|
||||||
|
ad_type_str = 'Complete Local Name'
|
||||||
|
try:
|
||||||
|
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
ad_data_str = ad_data.hex()
|
||||||
|
case AdvertisingData.TX_POWER_LEVEL:
|
||||||
|
ad_type_str = 'TX Power Level'
|
||||||
|
ad_data_str = str(ad_data[0])
|
||||||
|
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
||||||
|
ad_type_str = 'Manufacturer Specific Data'
|
||||||
|
company_id = struct.unpack_from('<H', ad_data, 0)[0]
|
||||||
|
company_name = COMPANY_IDENTIFIERS.get(
|
||||||
|
company_id, f'0x{company_id:04X}'
|
||||||
|
)
|
||||||
|
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
|
||||||
|
case AdvertisingData.APPEARANCE:
|
||||||
|
ad_type_str = 'Appearance'
|
||||||
|
appearance = Appearance.from_int(
|
||||||
|
struct.unpack_from('<H', ad_data, 0)[0]
|
||||||
|
)
|
||||||
|
ad_data_str = str(appearance)
|
||||||
|
case AdvertisingData.BROADCAST_NAME:
|
||||||
|
ad_type_str = 'Broadcast Name'
|
||||||
|
ad_data_str = ad_data.decode('utf-8')
|
||||||
|
case _:
|
||||||
|
ad_type_str = AdvertisingData.Type(ad_type).name
|
||||||
ad_data_str = ad_data.hex()
|
ad_data_str = ad_data.hex()
|
||||||
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
|
|
||||||
ad_type_str = 'TX Power Level'
|
|
||||||
ad_data_str = str(ad_data[0])
|
|
||||||
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
|
||||||
ad_type_str = 'Manufacturer Specific Data'
|
|
||||||
company_id = struct.unpack_from('<H', ad_data, 0)[0]
|
|
||||||
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
|
|
||||||
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
|
|
||||||
elif ad_type == AdvertisingData.APPEARANCE:
|
|
||||||
ad_type_str = 'Appearance'
|
|
||||||
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
|
|
||||||
ad_data_str = str(appearance)
|
|
||||||
elif ad_type == AdvertisingData.BROADCAST_NAME:
|
|
||||||
ad_type_str = 'Broadcast Name'
|
|
||||||
ad_data_str = ad_data.decode('utf-8')
|
|
||||||
else:
|
|
||||||
ad_type_str = AdvertisingData.Type(ad_type).name
|
|
||||||
ad_data_str = ad_data.hex()
|
|
||||||
|
|
||||||
return f'[{ad_type_str}]: {ad_data_str}'
|
return f'[{ad_type_str}]: {ad_data_str}'
|
||||||
|
|
||||||
@@ -2105,13 +2116,10 @@ class AdvertisingData:
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Connection PHY
|
# Connection PHY
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
@dataclasses.dataclass
|
||||||
class ConnectionPHY:
|
class ConnectionPHY:
|
||||||
def __init__(self, tx_phy, rx_phy):
|
tx_phy: int
|
||||||
self.tx_phy = tx_phy
|
rx_phy: int
|
||||||
self.rx_phy = rx_phy
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
+92
-11
@@ -1837,6 +1837,7 @@ class Connection(utils.CompositeEventEmitter):
|
|||||||
self.pairing_peer_io_capability = None
|
self.pairing_peer_io_capability = None
|
||||||
self.pairing_peer_authentication_requirements = None
|
self.pairing_peer_authentication_requirements = None
|
||||||
self.peer_le_features = hci.LeFeatureMask(0)
|
self.peer_le_features = hci.LeFeatureMask(0)
|
||||||
|
self.peer_classic_features = hci.LmpFeatureMask(0)
|
||||||
self.cs_configs = {}
|
self.cs_configs = {}
|
||||||
self.cs_procedures = {}
|
self.cs_procedures = {}
|
||||||
|
|
||||||
@@ -2054,6 +2055,15 @@ class Connection(utils.CompositeEventEmitter):
|
|||||||
self.peer_le_features = await self.device.get_remote_le_features(self)
|
self.peer_le_features = await self.device.get_remote_le_features(self)
|
||||||
return self.peer_le_features
|
return self.peer_le_features
|
||||||
|
|
||||||
|
async def get_remote_classic_features(self) -> hci.LmpFeatureMask:
|
||||||
|
"""[Classic Only] Reads remote LMP supported features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LMP features supported by the remote device.
|
||||||
|
"""
|
||||||
|
self.peer_classic_features = await self.device.get_remote_classic_features(self)
|
||||||
|
return self.peer_classic_features
|
||||||
|
|
||||||
def on_att_mtu_update(self, mtu: int):
|
def on_att_mtu_update(self, mtu: int):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
|
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
|
||||||
@@ -2149,6 +2159,7 @@ class DeviceConfiguration:
|
|||||||
)
|
)
|
||||||
eatt_enabled: bool = False
|
eatt_enabled: bool = False
|
||||||
gatt_services: list[dict[str, Any]] = field(init=False)
|
gatt_services: list[dict[str, Any]] = field(init=False)
|
||||||
|
smp_debug_mode: bool = False
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.gatt_services = []
|
self.gatt_services = []
|
||||||
@@ -2332,6 +2343,9 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
_pending_cis: dict[int, tuple[int, int]]
|
_pending_cis: dict[int, tuple[int, int]]
|
||||||
gatt_service: gatt_service.GenericAttributeProfileService | None = None
|
gatt_service: gatt_service.GenericAttributeProfileService | None = None
|
||||||
keystore: KeyStore | None = None
|
keystore: KeyStore | None = None
|
||||||
|
inquiry_response: bytes | None = None
|
||||||
|
address_resolver: smp.AddressResolver | None = None
|
||||||
|
connect_own_address_type: hci.OwnAddressType | None = None
|
||||||
|
|
||||||
EVENT_ADVERTISEMENT = "advertisement"
|
EVENT_ADVERTISEMENT = "advertisement"
|
||||||
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
|
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
|
||||||
@@ -2450,17 +2464,12 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
self.bis_links = {}
|
self.bis_links = {}
|
||||||
self.big_syncs = {}
|
self.big_syncs = {}
|
||||||
self.classic_enabled = False
|
self.classic_enabled = False
|
||||||
self.inquiry_response = None
|
|
||||||
self.address_resolver = None
|
|
||||||
self.classic_pending_accepts = {
|
self.classic_pending_accepts = {
|
||||||
hci.Address.ANY: []
|
hci.Address.ANY: []
|
||||||
} # Futures, by BD address OR [Futures] for hci.Address.ANY
|
} # Futures, by BD address OR [Futures] for hci.Address.ANY
|
||||||
|
|
||||||
self._cis_lock = asyncio.Lock()
|
self._cis_lock = asyncio.Lock()
|
||||||
|
|
||||||
# Own address type cache
|
|
||||||
self.connect_own_address_type = None
|
|
||||||
|
|
||||||
self.name = config.name
|
self.name = config.name
|
||||||
self.public_address = hci.Address.ANY
|
self.public_address = hci.Address.ANY
|
||||||
self.random_address = config.address
|
self.random_address = config.address
|
||||||
@@ -2561,6 +2570,7 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.smp_manager.debug_mode = self.config.smp_debug_mode
|
||||||
|
|
||||||
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
||||||
|
|
||||||
@@ -5281,6 +5291,77 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
)
|
)
|
||||||
return await read_feature_future
|
return await read_feature_future
|
||||||
|
|
||||||
|
async def get_remote_classic_features(
|
||||||
|
self, connection: Connection
|
||||||
|
) -> hci.LmpFeatureMask:
|
||||||
|
"""[Classic Only] Reads remote LE supported features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handle: connection handle to read LMP features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LMP features supported by the remote device.
|
||||||
|
"""
|
||||||
|
with closing(utils.EventWatcher()) as watcher:
|
||||||
|
read_feature_future: asyncio.Future[tuple[int, int]] = (
|
||||||
|
asyncio.get_running_loop().create_future()
|
||||||
|
)
|
||||||
|
read_features = hci.LmpFeatureMask(0)
|
||||||
|
current_page_number = 0
|
||||||
|
|
||||||
|
@watcher.on(self.host, 'classic_remote_features')
|
||||||
|
def on_classic_remote_features(
|
||||||
|
handle: int,
|
||||||
|
status: int,
|
||||||
|
features: int,
|
||||||
|
page_number: int,
|
||||||
|
max_page_number: int,
|
||||||
|
) -> None:
|
||||||
|
if handle != connection.handle:
|
||||||
|
logger.warning(
|
||||||
|
"Received classic_remote_features for wrong handle, expected=0x%04X, got=0x%04X",
|
||||||
|
connection.handle,
|
||||||
|
handle,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if page_number != current_page_number:
|
||||||
|
logger.warning(
|
||||||
|
"Received classic_remote_features for wrong page, expected=%d, got=%d",
|
||||||
|
current_page_number,
|
||||||
|
page_number,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if status == hci.HCI_ErrorCode.SUCCESS:
|
||||||
|
read_feature_future.set_result((features, max_page_number))
|
||||||
|
else:
|
||||||
|
read_feature_future.set_exception(hci.HCI_Error(status))
|
||||||
|
|
||||||
|
await self.send_async_command(
|
||||||
|
hci.HCI_Read_Remote_Supported_Features_Command(
|
||||||
|
connection_handle=connection.handle
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_features, max_page_number = await read_feature_future
|
||||||
|
read_features |= new_features
|
||||||
|
if not (read_features & hci.LmpFeatureMask.EXTENDED_FEATURES):
|
||||||
|
return read_features
|
||||||
|
|
||||||
|
while current_page_number <= max_page_number:
|
||||||
|
read_feature_future = asyncio.get_running_loop().create_future()
|
||||||
|
await self.send_async_command(
|
||||||
|
hci.HCI_Read_Remote_Extended_Features_Command(
|
||||||
|
connection_handle=connection.handle,
|
||||||
|
page_number=current_page_number,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_features, max_page_number = await read_feature_future
|
||||||
|
read_features |= new_features << (current_page_number * 64)
|
||||||
|
current_page_number += 1
|
||||||
|
|
||||||
|
return read_features
|
||||||
|
|
||||||
@utils.experimental('Only for testing.')
|
@utils.experimental('Only for testing.')
|
||||||
async def get_remote_cs_capabilities(
|
async def get_remote_cs_capabilities(
|
||||||
self, connection: Connection
|
self, connection: Connection
|
||||||
@@ -5535,8 +5616,8 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
async def notify_subscriber(
|
async def notify_subscriber(
|
||||||
self,
|
self,
|
||||||
connection: Connection,
|
connection: Connection,
|
||||||
attribute: Attribute,
|
attribute: Attribute[_T],
|
||||||
value: Any | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -5555,7 +5636,7 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def notify_subscribers(
|
async def notify_subscribers(
|
||||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send a notification to all the subscribers of an attribute.
|
Send a notification to all the subscribers of an attribute.
|
||||||
@@ -5574,8 +5655,8 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
async def indicate_subscriber(
|
async def indicate_subscriber(
|
||||||
self,
|
self,
|
||||||
connection: Connection,
|
connection: Connection,
|
||||||
attribute: Attribute,
|
attribute: Attribute[_T],
|
||||||
value: Any | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -5596,7 +5677,7 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def indicate_subscribers(
|
async def indicate_subscribers(
|
||||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Send an indication to all the subscribers of an attribute.
|
Send an indication to all the subscribers of an attribute.
|
||||||
|
|||||||
+44
-43
@@ -201,50 +201,51 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
|
|||||||
value = data[2 : 2 + value_length]
|
value = data[2 : 2 + value_length]
|
||||||
typed_value: Any
|
typed_value: Any
|
||||||
|
|
||||||
if value_type == ValueType.END:
|
match value_type:
|
||||||
break
|
case ValueType.END:
|
||||||
|
break
|
||||||
|
|
||||||
if value_type in (ValueType.CNVI, ValueType.CNVR):
|
case ValueType.CNVI | ValueType.CNVR:
|
||||||
(v,) = struct.unpack("<I", value)
|
(v,) = struct.unpack("<I", value)
|
||||||
typed_value = (
|
typed_value = (
|
||||||
(((v >> 0) & 0xF) << 12)
|
(((v >> 0) & 0xF) << 12)
|
||||||
| (((v >> 4) & 0xF) << 0)
|
| (((v >> 4) & 0xF) << 0)
|
||||||
| (((v >> 8) & 0xF) << 4)
|
| (((v >> 8) & 0xF) << 4)
|
||||||
| (((v >> 24) & 0xF) << 8)
|
| (((v >> 24) & 0xF) << 8)
|
||||||
)
|
)
|
||||||
elif value_type == ValueType.HARDWARE_INFO:
|
case ValueType.HARDWARE_INFO:
|
||||||
(v,) = struct.unpack("<I", value)
|
(v,) = struct.unpack("<I", value)
|
||||||
typed_value = HardwareInfo(
|
typed_value = HardwareInfo(
|
||||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||||
)
|
)
|
||||||
elif value_type in (
|
case (
|
||||||
ValueType.USB_VENDOR_ID,
|
ValueType.USB_VENDOR_ID
|
||||||
ValueType.USB_PRODUCT_ID,
|
| ValueType.USB_PRODUCT_ID
|
||||||
ValueType.DEVICE_REVISION,
|
| ValueType.DEVICE_REVISION
|
||||||
):
|
):
|
||||||
(typed_value,) = struct.unpack("<H", value)
|
(typed_value,) = struct.unpack("<H", value)
|
||||||
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
|
case ValueType.CURRENT_MODE_OF_OPERATION:
|
||||||
typed_value = ModeOfOperation(value[0])
|
typed_value = ModeOfOperation(value[0])
|
||||||
elif value_type in (
|
case (
|
||||||
ValueType.BUILD_TYPE,
|
ValueType.BUILD_TYPE
|
||||||
ValueType.BUILD_NUMBER,
|
| ValueType.BUILD_NUMBER
|
||||||
ValueType.SECURE_BOOT,
|
| ValueType.SECURE_BOOT
|
||||||
ValueType.OTP_LOCK,
|
| ValueType.OTP_LOCK
|
||||||
ValueType.API_LOCK,
|
| ValueType.API_LOCK
|
||||||
ValueType.DEBUG_LOCK,
|
| ValueType.DEBUG_LOCK
|
||||||
ValueType.SECURE_BOOT_ENGINE_TYPE,
|
| ValueType.SECURE_BOOT_ENGINE_TYPE
|
||||||
):
|
):
|
||||||
typed_value = value[0]
|
typed_value = value[0]
|
||||||
elif value_type == ValueType.TIMESTAMP:
|
case ValueType.TIMESTAMP:
|
||||||
typed_value = Timestamp(value[0], value[1])
|
typed_value = Timestamp(value[0], value[1])
|
||||||
elif value_type == ValueType.FIRMWARE_BUILD:
|
case ValueType.FIRMWARE_BUILD:
|
||||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||||
elif value_type == ValueType.BLUETOOTH_ADDRESS:
|
case ValueType.BLUETOOTH_ADDRESS:
|
||||||
typed_value = hci.Address(
|
typed_value = hci.Address(
|
||||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||||
)
|
)
|
||||||
else:
|
case _:
|
||||||
typed_value = value
|
typed_value = value
|
||||||
|
|
||||||
result.append((value_type, typed_value))
|
result.append((value_type, typed_value))
|
||||||
data = data[2 + value_length :]
|
data = data[2 + value_length :]
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
# Copyright 2021-2022 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Imports
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
import logging
|
|
||||||
import struct
|
|
||||||
|
|
||||||
from bumble.gatt import (
|
|
||||||
GATT_APPEARANCE_CHARACTERISTIC,
|
|
||||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
|
||||||
GATT_GENERIC_ACCESS_SERVICE,
|
|
||||||
Characteristic,
|
|
||||||
Service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Logging
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Classes
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
class GenericAccessService(Service):
|
|
||||||
def __init__(self, device_name, appearance=(0, 0)):
|
|
||||||
device_name_characteristic = Characteristic(
|
|
||||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
|
||||||
Characteristic.Properties.READ,
|
|
||||||
Characteristic.READABLE,
|
|
||||||
device_name.encode('utf-8')[:248],
|
|
||||||
)
|
|
||||||
|
|
||||||
appearance_characteristic = Characteristic(
|
|
||||||
GATT_APPEARANCE_CHARACTERISTIC,
|
|
||||||
Characteristic.Properties.READ,
|
|
||||||
Characteristic.READABLE,
|
|
||||||
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
GATT_GENERIC_ACCESS_SERVICE,
|
|
||||||
[device_name_characteristic, appearance_characteristic],
|
|
||||||
)
|
|
||||||
+24
-22
@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
|||||||
# Helpers
|
# Helpers
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|
||||||
def _bearer_id(bearer: att.Bearer) -> str:
|
def _bearer_id(bearer: att.Bearer) -> str:
|
||||||
if att.is_enhanced_bearer(bearer):
|
if att.is_enhanced_bearer(bearer):
|
||||||
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def notify_subscriber(
|
async def notify_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if att.is_enhanced_bearer(bearer) or force:
|
if att.is_enhanced_bearer(bearer) or force:
|
||||||
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _notify_single_subscriber(
|
async def _notify_single_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None,
|
value: _T | None,
|
||||||
force: bool,
|
force: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get or encode the value
|
# Get or encode the value
|
||||||
value = (
|
value_as_bytes = (
|
||||||
await attribute.read_value(bearer)
|
await attribute.read_value(bearer)
|
||||||
if value is None
|
if value is None
|
||||||
else attribute.encode_value(value)
|
else attribute.encode_value(value)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
if len(value) > bearer.att_mtu - 3:
|
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||||
value = value[: bearer.att_mtu - 3]
|
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||||
|
|
||||||
# Notify
|
# Notify
|
||||||
notification = att.ATT_Handle_Value_Notification(
|
notification = att.ATT_Handle_Value_Notification(
|
||||||
attribute_handle=attribute.handle, attribute_value=value
|
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||||
)
|
)
|
||||||
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
||||||
self.send_gatt_pdu(bearer, bytes(notification))
|
self.send_gatt_pdu(bearer, bytes(notification))
|
||||||
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def indicate_subscriber(
|
async def indicate_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if att.is_enhanced_bearer(bearer) or force:
|
if att.is_enhanced_bearer(bearer) or force:
|
||||||
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _indicate_single_bearer(
|
async def _indicate_single_bearer(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None,
|
value: _T | None,
|
||||||
force: bool,
|
force: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get or encode the value
|
# Get or encode the value
|
||||||
value = (
|
value_as_bytes = (
|
||||||
await attribute.read_value(bearer)
|
await attribute.read_value(bearer)
|
||||||
if value is None
|
if value is None
|
||||||
else attribute.encode_value(value)
|
else attribute.encode_value(value)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
if len(value) > bearer.att_mtu - 3:
|
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||||
value = value[: bearer.att_mtu - 3]
|
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||||
|
|
||||||
# Indicate
|
# Indicate
|
||||||
indication = att.ATT_Handle_Value_Indication(
|
indication = att.ATT_Handle_Value_Indication(
|
||||||
attribute_handle=attribute.handle, attribute_value=value
|
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||||
)
|
)
|
||||||
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
||||||
|
|
||||||
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _notify_or_indicate_subscribers(
|
async def _notify_or_indicate_subscribers(
|
||||||
self,
|
self,
|
||||||
indicate: bool,
|
indicate: bool,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Get all the bearers for which there's at least one subscription
|
# Get all the bearers for which there's at least one subscription
|
||||||
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
|
|||||||
|
|
||||||
async def notify_subscribers(
|
async def notify_subscribers(
|
||||||
self,
|
self,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
return await self._notify_or_indicate_subscribers(
|
return await self._notify_or_indicate_subscribers(
|
||||||
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
|
|||||||
|
|
||||||
async def indicate_subscribers(
|
async def indicate_subscribers(
|
||||||
self,
|
self,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||||
|
|||||||
+88
-109
@@ -31,6 +31,7 @@ from typing import (
|
|||||||
ClassVar,
|
ClassVar,
|
||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
|
SupportsBytes,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@@ -247,28 +248,6 @@ HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
|
|||||||
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
|
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
|
||||||
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
|
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
|
||||||
|
|
||||||
HCI_VERSION_NAMES = {
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_1_1: 'HCI_VERSION_BLUETOOTH_CORE_1_1',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_1_2: 'HCI_VERSION_BLUETOOTH_CORE_1_2',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_0_EDR',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_1_EDR',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_3_0_HS: 'HCI_VERSION_BLUETOOTH_CORE_3_0_HS',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_4_0: 'HCI_VERSION_BLUETOOTH_CORE_4_0',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_4_1: 'HCI_VERSION_BLUETOOTH_CORE_4_1',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_4_2: 'HCI_VERSION_BLUETOOTH_CORE_4_2',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
|
|
||||||
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
|
|
||||||
}
|
|
||||||
|
|
||||||
LMP_VERSION_NAMES = HCI_VERSION_NAMES
|
|
||||||
|
|
||||||
# HCI Packet types
|
# HCI Packet types
|
||||||
HCI_COMMAND_PACKET = 0x01
|
HCI_COMMAND_PACKET = 0x01
|
||||||
HCI_ACL_DATA_PACKET = 0x02
|
HCI_ACL_DATA_PACKET = 0x02
|
||||||
@@ -387,8 +366,8 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
|
|||||||
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
|
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
|
||||||
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
|
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
|
||||||
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
|
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
|
||||||
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
|
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2A
|
||||||
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
|
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2B
|
||||||
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
|
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
|
||||||
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
|
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
|
||||||
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
|
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
|
||||||
@@ -1860,44 +1839,46 @@ class HCI_Object:
|
|||||||
field_type = field_type['parser']
|
field_type = field_type['parser']
|
||||||
|
|
||||||
# Parse the field
|
# Parse the field
|
||||||
if field_type == '*':
|
match field_type:
|
||||||
# The rest of the bytes
|
case '*':
|
||||||
field_value = data[offset:]
|
# The rest of the bytes
|
||||||
return (field_value, len(field_value))
|
field_value = data[offset:]
|
||||||
if field_type == 'v':
|
return (field_value, len(field_value))
|
||||||
# Variable-length bytes field, with 1-byte length at the beginning
|
case 'v':
|
||||||
field_length = data[offset]
|
# Variable-length bytes field, with 1-byte length at the beginning
|
||||||
offset += 1
|
field_length = data[offset]
|
||||||
field_value = data[offset : offset + field_length]
|
offset += 1
|
||||||
return (field_value, field_length + 1)
|
field_value = data[offset : offset + field_length]
|
||||||
if field_type == 1:
|
return (field_value, field_length + 1)
|
||||||
# 8-bit unsigned
|
case 1:
|
||||||
return (data[offset], 1)
|
# 8-bit unsigned
|
||||||
if field_type == -1:
|
return (data[offset], 1)
|
||||||
# 8-bit signed
|
case -1:
|
||||||
return (struct.unpack_from('b', data, offset)[0], 1)
|
# 8-bit signed
|
||||||
if field_type == 2:
|
return (struct.unpack_from('b', data, offset)[0], 1)
|
||||||
# 16-bit unsigned
|
case 2:
|
||||||
return (struct.unpack_from('<H', data, offset)[0], 2)
|
# 16-bit unsigned
|
||||||
if field_type == '>2':
|
return (struct.unpack_from('<H', data, offset)[0], 2)
|
||||||
# 16-bit unsigned big-endian
|
case '>2':
|
||||||
return (struct.unpack_from('>H', data, offset)[0], 2)
|
# 16-bit unsigned big-endian
|
||||||
if field_type == -2:
|
return (struct.unpack_from('>H', data, offset)[0], 2)
|
||||||
# 16-bit signed
|
case -2:
|
||||||
return (struct.unpack_from('<h', data, offset)[0], 2)
|
# 16-bit signed
|
||||||
if field_type == 3:
|
return (struct.unpack_from('<h', data, offset)[0], 2)
|
||||||
# 24-bit unsigned
|
case 3:
|
||||||
padded = data[offset : offset + 3] + bytes([0])
|
# 24-bit unsigned
|
||||||
return (struct.unpack('<I', padded)[0], 3)
|
padded = data[offset : offset + 3] + bytes([0])
|
||||||
if field_type == 4:
|
return (struct.unpack('<I', padded)[0], 3)
|
||||||
# 32-bit unsigned
|
case 4:
|
||||||
return (struct.unpack_from('<I', data, offset)[0], 4)
|
# 32-bit unsigned
|
||||||
if field_type == '>4':
|
return (struct.unpack_from('<I', data, offset)[0], 4)
|
||||||
# 32-bit unsigned big-endian
|
case '>4':
|
||||||
return (struct.unpack_from('>I', data, offset)[0], 4)
|
# 32-bit unsigned big-endian
|
||||||
if isinstance(field_type, int) and 4 < field_type <= 256:
|
return (struct.unpack_from('>I', data, offset)[0], 4)
|
||||||
# Byte array (from 5 up to 256 bytes)
|
case int() if 4 < field_type <= 256:
|
||||||
return (data[offset : offset + field_type], field_type)
|
# Byte array (from 5 up to 256 bytes)
|
||||||
|
return (data[offset : offset + field_type], field_type)
|
||||||
|
|
||||||
if callable(field_type):
|
if callable(field_type):
|
||||||
new_offset, field_value = field_type(data, offset)
|
new_offset, field_value = field_type(data, offset)
|
||||||
return (field_value, new_offset - offset)
|
return (field_value, new_offset - offset)
|
||||||
@@ -1954,60 +1935,58 @@ class HCI_Object:
|
|||||||
|
|
||||||
# Serialize the field
|
# Serialize the field
|
||||||
if serializer:
|
if serializer:
|
||||||
field_bytes = serializer(field_value)
|
return serializer(field_value)
|
||||||
elif field_type == 1:
|
match field_type:
|
||||||
# 8-bit unsigned
|
case 1:
|
||||||
field_bytes = bytes([field_value])
|
# 8-bit unsigned
|
||||||
elif field_type == -1:
|
return bytes([field_value])
|
||||||
# 8-bit signed
|
case -1:
|
||||||
field_bytes = struct.pack('b', field_value)
|
# 8-bit signed
|
||||||
elif field_type == 2:
|
return struct.pack('b', field_value)
|
||||||
# 16-bit unsigned
|
case 2:
|
||||||
field_bytes = struct.pack('<H', field_value)
|
# 16-bit unsigned
|
||||||
elif field_type == '>2':
|
return struct.pack('<H', field_value)
|
||||||
# 16-bit unsigned big-endian
|
case '>2':
|
||||||
field_bytes = struct.pack('>H', field_value)
|
# 16-bit unsigned big-endian
|
||||||
elif field_type == -2:
|
return struct.pack('>H', field_value)
|
||||||
# 16-bit signed
|
case -2:
|
||||||
field_bytes = struct.pack('<h', field_value)
|
# 16-bit signed
|
||||||
elif field_type == 3:
|
return struct.pack('<h', field_value)
|
||||||
# 24-bit unsigned
|
case 3:
|
||||||
field_bytes = struct.pack('<I', field_value)[0:3]
|
# 24-bit unsigned
|
||||||
elif field_type == 4:
|
return struct.pack('<I', field_value)[0:3]
|
||||||
# 32-bit unsigned
|
case 4:
|
||||||
field_bytes = struct.pack('<I', field_value)
|
# 32-bit unsigned
|
||||||
elif field_type == '>4':
|
return struct.pack('<I', field_value)
|
||||||
# 32-bit unsigned big-endian
|
case '>4':
|
||||||
field_bytes = struct.pack('>I', field_value)
|
# 32-bit unsigned big-endian
|
||||||
elif field_type == '*':
|
return struct.pack('>I', field_value)
|
||||||
if isinstance(field_value, int):
|
case '*':
|
||||||
if 0 <= field_value <= 255:
|
if isinstance(field_value, int):
|
||||||
field_bytes = bytes([field_value])
|
if 0 <= field_value <= 255:
|
||||||
|
return bytes([field_value])
|
||||||
|
else:
|
||||||
|
raise InvalidArgumentError('value too large for *-typed field')
|
||||||
else:
|
else:
|
||||||
raise InvalidArgumentError('value too large for *-typed field')
|
return bytes(field_value)
|
||||||
else:
|
case 'v':
|
||||||
|
# Variable-length bytes field, with 1-byte length at the beginning
|
||||||
field_bytes = bytes(field_value)
|
field_bytes = bytes(field_value)
|
||||||
elif field_type == 'v':
|
field_length = len(field_bytes)
|
||||||
# Variable-length bytes field, with 1-byte length at the beginning
|
return bytes([field_length]) + field_bytes
|
||||||
field_bytes = bytes(field_value)
|
if isinstance(field_value, (bytes, bytearray, SupportsBytes)):
|
||||||
field_length = len(field_bytes)
|
|
||||||
field_bytes = bytes([field_length]) + field_bytes
|
|
||||||
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
|
|
||||||
field_value, '__bytes__'
|
|
||||||
):
|
|
||||||
field_bytes = bytes(field_value)
|
field_bytes = bytes(field_value)
|
||||||
if isinstance(field_type, int) and 4 < field_type <= 256:
|
if isinstance(field_type, int) and 4 < field_type <= 256:
|
||||||
# Truncate or pad with zeros if the field is too long or too short
|
# Truncate or pad with zeros if the field is too long or too short
|
||||||
if len(field_bytes) < field_type:
|
if len(field_bytes) < field_type:
|
||||||
field_bytes += bytes(field_type - len(field_bytes))
|
return field_bytes + bytes(field_type - len(field_bytes))
|
||||||
elif len(field_bytes) > field_type:
|
elif len(field_bytes) > field_type:
|
||||||
field_bytes = field_bytes[:field_type]
|
return field_bytes[:field_type]
|
||||||
else:
|
return field_bytes
|
||||||
raise InvalidArgumentError(
|
|
||||||
f"don't know how to serialize type {type(field_value)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return field_bytes
|
raise InvalidArgumentError(
|
||||||
|
f"don't know how to serialize type {type(field_value)}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dict_to_bytes(hci_object, object_fields):
|
def dict_to_bytes(hci_object, object_fields):
|
||||||
@@ -4736,7 +4715,7 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_SyncCommand[HCI_StatusReturnParame
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class HCI_LE_Read_Resolving_List_Size_ReturnParameters(HCI_StatusReturnParameters):
|
class HCI_LE_Read_Resolving_List_Size_ReturnParameters(HCI_StatusReturnParameters):
|
||||||
resolving_list_size: bytes = field(metadata=metadata(1))
|
resolving_list_size: int = field(metadata=metadata(1))
|
||||||
|
|
||||||
|
|
||||||
@HCI_SyncCommand.sync_command(HCI_LE_Read_Resolving_List_Size_ReturnParameters)
|
@HCI_SyncCommand.sync_command(HCI_LE_Read_Resolving_List_Size_ReturnParameters)
|
||||||
|
|||||||
+72
-90
@@ -26,7 +26,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar
|
from typing import Any, ClassVar, Literal, overload
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@@ -68,6 +68,8 @@ class HfpProtocolError(ProtocolError):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class HfpProtocol:
|
class HfpProtocol:
|
||||||
|
MAX_BUFFER_SIZE: ClassVar[int] = 65536
|
||||||
|
|
||||||
dlc: rfcomm.DLC
|
dlc: rfcomm.DLC
|
||||||
buffer: str
|
buffer: str
|
||||||
lines: collections.deque
|
lines: collections.deque
|
||||||
@@ -84,10 +86,19 @@ class HfpProtocol:
|
|||||||
def feed(self, data: bytes | str) -> None:
|
def feed(self, data: bytes | str) -> None:
|
||||||
# Convert the data to a string if needed
|
# Convert the data to a string if needed
|
||||||
if isinstance(data, bytes):
|
if isinstance(data, bytes):
|
||||||
data = data.decode('utf-8')
|
data = data.decode('utf-8', errors='replace')
|
||||||
|
|
||||||
logger.debug(f'<<< Data received: {data}')
|
logger.debug(f'<<< Data received: {data}')
|
||||||
|
|
||||||
|
# Drop incoming data if it would overflow the buffer; keep existing
|
||||||
|
# partial packet state intact so a future clean packet can still parse.
|
||||||
|
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
|
||||||
|
logger.warning(
|
||||||
|
'HFP buffer overflow (>%d bytes), dropping incoming data',
|
||||||
|
self.MAX_BUFFER_SIZE,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Add to the buffer and look for lines
|
# Add to the buffer and look for lines
|
||||||
self.buffer += data
|
self.buffer += data
|
||||||
while (separator := self.buffer.find('\r')) >= 0:
|
while (separator := self.buffer.find('\r')) >= 0:
|
||||||
@@ -420,61 +431,6 @@ class CmeError(enum.IntEnum):
|
|||||||
# Hands-Free Control Interoperability Requirements
|
# Hands-Free Control Interoperability Requirements
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
# Response codes.
|
|
||||||
RESPONSE_CODES = {
|
|
||||||
"+APLSIRI",
|
|
||||||
"+BAC",
|
|
||||||
"+BCC",
|
|
||||||
"+BCS",
|
|
||||||
"+BIA",
|
|
||||||
"+BIEV",
|
|
||||||
"+BIND",
|
|
||||||
"+BINP",
|
|
||||||
"+BLDN",
|
|
||||||
"+BRSF",
|
|
||||||
"+BTRH",
|
|
||||||
"+BVRA",
|
|
||||||
"+CCWA",
|
|
||||||
"+CHLD",
|
|
||||||
"+CHUP",
|
|
||||||
"+CIND",
|
|
||||||
"+CLCC",
|
|
||||||
"+CLIP",
|
|
||||||
"+CMEE",
|
|
||||||
"+CMER",
|
|
||||||
"+CNUM",
|
|
||||||
"+COPS",
|
|
||||||
"+IPHONEACCEV",
|
|
||||||
"+NREC",
|
|
||||||
"+VGM",
|
|
||||||
"+VGS",
|
|
||||||
"+VTS",
|
|
||||||
"+XAPL",
|
|
||||||
"A",
|
|
||||||
"D",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Unsolicited responses and statuses.
|
|
||||||
UNSOLICITED_CODES = {
|
|
||||||
"+APLSIRI",
|
|
||||||
"+BCS",
|
|
||||||
"+BIND",
|
|
||||||
"+BSIR",
|
|
||||||
"+BTRH",
|
|
||||||
"+BVRA",
|
|
||||||
"+CCWA",
|
|
||||||
"+CIEV",
|
|
||||||
"+CLIP",
|
|
||||||
"+VGM",
|
|
||||||
"+VGS",
|
|
||||||
"BLACKLISTED",
|
|
||||||
"BUSY",
|
|
||||||
"DELAYED",
|
|
||||||
"NO ANSWER",
|
|
||||||
"NO CARRIER",
|
|
||||||
"RING",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Status codes
|
# Status codes
|
||||||
STATUS_CODES = {
|
STATUS_CODES = {
|
||||||
"+CME ERROR",
|
"+CME ERROR",
|
||||||
@@ -727,12 +683,9 @@ class HfProtocol(utils.EventEmitter):
|
|||||||
|
|
||||||
dlc: rfcomm.DLC
|
dlc: rfcomm.DLC
|
||||||
command_lock: asyncio.Lock
|
command_lock: asyncio.Lock
|
||||||
if TYPE_CHECKING:
|
pending_command: str | None = None
|
||||||
response_queue: asyncio.Queue[AtResponse]
|
response_queue: asyncio.Queue[AtResponse]
|
||||||
unsolicited_queue: asyncio.Queue[AtResponse | None]
|
unsolicited_queue: asyncio.Queue[AtResponse | None]
|
||||||
else:
|
|
||||||
response_queue: asyncio.Queue
|
|
||||||
unsolicited_queue: asyncio.Queue
|
|
||||||
read_buffer: bytearray
|
read_buffer: bytearray
|
||||||
active_codec: AudioCodec
|
active_codec: AudioCodec
|
||||||
|
|
||||||
@@ -805,16 +758,39 @@ class HfProtocol(utils.EventEmitter):
|
|||||||
self.read_buffer = self.read_buffer[trailer + 2 :]
|
self.read_buffer = self.read_buffer[trailer + 2 :]
|
||||||
|
|
||||||
# Forward the received code to the correct queue.
|
# Forward the received code to the correct queue.
|
||||||
if self.command_lock.locked() and (
|
if self.pending_command and (
|
||||||
response.code in STATUS_CODES or response.code in RESPONSE_CODES
|
response.code in STATUS_CODES or response.code in self.pending_command
|
||||||
):
|
):
|
||||||
self.response_queue.put_nowait(response)
|
self.response_queue.put_nowait(response)
|
||||||
elif response.code in UNSOLICITED_CODES:
|
|
||||||
self.unsolicited_queue.put_nowait(response)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self.unsolicited_queue.put_nowait(response)
|
||||||
f"dropping unexpected response with code '{response.code}'"
|
|
||||||
)
|
@overload
|
||||||
|
async def execute_command(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
timeout: float = 1.0,
|
||||||
|
*,
|
||||||
|
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def execute_command(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
timeout: float = 1.0,
|
||||||
|
*,
|
||||||
|
response_type: Literal[AtResponseType.SINGLE],
|
||||||
|
) -> AtResponse: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def execute_command(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
timeout: float = 1.0,
|
||||||
|
*,
|
||||||
|
response_type: Literal[AtResponseType.MULTIPLE],
|
||||||
|
) -> list[AtResponse]: ...
|
||||||
|
|
||||||
async def execute_command(
|
async def execute_command(
|
||||||
self,
|
self,
|
||||||
@@ -835,27 +811,34 @@ class HfProtocol(utils.EventEmitter):
|
|||||||
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
|
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
|
||||||
ProtocolError: the status is not OK.
|
ProtocolError: the status is not OK.
|
||||||
"""
|
"""
|
||||||
async with self.command_lock:
|
try:
|
||||||
logger.debug(f">>> {cmd}")
|
async with self.command_lock:
|
||||||
self.dlc.write(cmd + '\r')
|
self.pending_command = cmd
|
||||||
responses: list[AtResponse] = []
|
logger.debug(f">>> {cmd}")
|
||||||
|
self.dlc.write(cmd + '\r')
|
||||||
|
responses: list[AtResponse] = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self.response_queue.get(), timeout=timeout
|
self.response_queue.get(), timeout=timeout
|
||||||
)
|
)
|
||||||
if result.code == 'OK':
|
if result.code == 'OK':
|
||||||
if response_type == AtResponseType.SINGLE and len(responses) != 1:
|
if (
|
||||||
raise HfpProtocolError("NO ANSWER")
|
response_type == AtResponseType.SINGLE
|
||||||
|
and len(responses) != 1
|
||||||
|
):
|
||||||
|
raise HfpProtocolError("NO ANSWER")
|
||||||
|
|
||||||
if response_type == AtResponseType.MULTIPLE:
|
if response_type == AtResponseType.MULTIPLE:
|
||||||
return responses
|
return responses
|
||||||
if response_type == AtResponseType.SINGLE:
|
if response_type == AtResponseType.SINGLE:
|
||||||
return responses[0]
|
return responses[0]
|
||||||
return None
|
return None
|
||||||
if result.code in STATUS_CODES:
|
if result.code in STATUS_CODES:
|
||||||
raise HfpProtocolError(result.code)
|
raise HfpProtocolError(result.code)
|
||||||
responses.append(result)
|
responses.append(result)
|
||||||
|
finally:
|
||||||
|
self.pending_command = None
|
||||||
|
|
||||||
async def initiate_slc(self):
|
async def initiate_slc(self):
|
||||||
"""4.2.1 Service Level Connection Initialization."""
|
"""4.2.1 Service Level Connection Initialization."""
|
||||||
@@ -1067,7 +1050,6 @@ class HfProtocol(utils.EventEmitter):
|
|||||||
responses = await self.execute_command(
|
responses = await self.execute_command(
|
||||||
"AT+CLCC", response_type=AtResponseType.MULTIPLE
|
"AT+CLCC", response_type=AtResponseType.MULTIPLE
|
||||||
)
|
)
|
||||||
assert isinstance(responses, list)
|
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
for response in responses:
|
for response in responses:
|
||||||
|
|||||||
+41
-17
@@ -22,7 +22,7 @@ import collections
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
|
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||||
|
|
||||||
from bumble import drivers, hci, utils
|
from bumble import drivers, hci, utils
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
@@ -692,10 +692,8 @@ class Host(utils.EventEmitter):
|
|||||||
finally:
|
finally:
|
||||||
self.pending_command = None
|
self.pending_command = None
|
||||||
self.pending_response = None
|
self.pending_response = None
|
||||||
if (
|
if response is None or (
|
||||||
response is not None
|
response.num_hci_command_packets and self.command_semaphore.locked()
|
||||||
and response.num_hci_command_packets
|
|
||||||
and self.command_semaphore.locked()
|
|
||||||
):
|
):
|
||||||
self.command_semaphore.release()
|
self.command_semaphore.release()
|
||||||
|
|
||||||
@@ -1002,18 +1000,19 @@ class Host(utils.EventEmitter):
|
|||||||
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
|
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
|
||||||
|
|
||||||
# If the packet is a command, invoke the handler for this packet
|
# If the packet is a command, invoke the handler for this packet
|
||||||
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
|
match packet:
|
||||||
self.on_hci_command_packet(cast(hci.HCI_Command, packet))
|
case hci.HCI_Command():
|
||||||
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
|
self.on_hci_command_packet(packet)
|
||||||
self.on_hci_event_packet(cast(hci.HCI_Event, packet))
|
case hci.HCI_Event():
|
||||||
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
|
self.on_hci_event_packet(packet)
|
||||||
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
|
case hci.HCI_AclDataPacket():
|
||||||
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
|
self.on_hci_acl_data_packet(packet)
|
||||||
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
|
case hci.HCI_SynchronousDataPacket():
|
||||||
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
|
self.on_hci_sco_data_packet(packet)
|
||||||
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
|
case hci.HCI_IsoDataPacket():
|
||||||
else:
|
self.on_hci_iso_data_packet(packet)
|
||||||
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
|
case _:
|
||||||
|
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
|
||||||
|
|
||||||
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
|
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
|
||||||
logger.warning(f'!!! unexpected command packet: {command}')
|
logger.warning(f'!!! unexpected command packet: {command}')
|
||||||
@@ -1659,6 +1658,19 @@ class Host(utils.EventEmitter):
|
|||||||
'connection_encryption_failure', event.connection_handle, event.status
|
'connection_encryption_failure', event.connection_handle, event.status
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_hci_read_remote_supported_features_complete_event(
|
||||||
|
self, event: hci.HCI_Read_Remote_Supported_Features_Complete_Event
|
||||||
|
) -> None:
|
||||||
|
# Notify the client
|
||||||
|
self.emit(
|
||||||
|
'classic_remote_features',
|
||||||
|
event.connection_handle,
|
||||||
|
event.status,
|
||||||
|
int.from_bytes(event.lmp_features, 'little'),
|
||||||
|
0, # page number
|
||||||
|
0, # max page number
|
||||||
|
)
|
||||||
|
|
||||||
def on_hci_encryption_change_v2_event(
|
def on_hci_encryption_change_v2_event(
|
||||||
self, event: hci.HCI_Encryption_Change_V2_Event
|
self, event: hci.HCI_Encryption_Change_V2_Event
|
||||||
):
|
):
|
||||||
@@ -1815,6 +1827,18 @@ class Host(utils.EventEmitter):
|
|||||||
rssi,
|
rssi,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_hci_read_remote_extended_features_complete_event(
|
||||||
|
self, event: hci.HCI_Read_Remote_Extended_Features_Complete_Event
|
||||||
|
):
|
||||||
|
self.emit(
|
||||||
|
'classic_remote_features',
|
||||||
|
event.connection_handle,
|
||||||
|
event.status,
|
||||||
|
int.from_bytes(event.extended_lmp_features, 'little'),
|
||||||
|
event.page_number,
|
||||||
|
event.maximum_page_number,
|
||||||
|
)
|
||||||
|
|
||||||
def on_hci_extended_inquiry_result_event(
|
def on_hci_extended_inquiry_result_event(
|
||||||
self, event: hci.HCI_Extended_Inquiry_Result_Event
|
self, event: hci.HCI_Extended_Inquiry_Result_Event
|
||||||
):
|
):
|
||||||
|
|||||||
+29
-29
@@ -27,6 +27,7 @@ import dataclasses
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@@ -248,29 +249,26 @@ class JsonKeyStore(KeyStore):
|
|||||||
DEFAULT_NAMESPACE = '__DEFAULT__'
|
DEFAULT_NAMESPACE = '__DEFAULT__'
|
||||||
DEFAULT_BASE_NAME = "keys"
|
DEFAULT_BASE_NAME = "keys"
|
||||||
|
|
||||||
def __init__(self, namespace, filename=None):
|
def __init__(
|
||||||
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
|
self, namespace: str | None = None, filename: str | None = None
|
||||||
|
) -> None:
|
||||||
|
self.namespace = namespace or self.DEFAULT_NAMESPACE
|
||||||
|
|
||||||
if filename is None:
|
if filename:
|
||||||
# Use a default for the current user
|
self.filename = pathlib.Path(filename).resolve()
|
||||||
|
self.directory_name = self.filename.parent
|
||||||
# Import here because this may not exist on all platforms
|
|
||||||
# pylint: disable=import-outside-toplevel
|
|
||||||
import appdirs
|
|
||||||
|
|
||||||
self.directory_name = os.path.join(
|
|
||||||
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
|
|
||||||
)
|
|
||||||
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
|
|
||||||
json_filename = (
|
|
||||||
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
|
|
||||||
)
|
|
||||||
self.filename = os.path.join(self.directory_name, json_filename)
|
|
||||||
else:
|
else:
|
||||||
self.filename = filename
|
import platformdirs # Deferred import
|
||||||
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
|
|
||||||
|
|
||||||
logger.debug(f'JSON keystore: {self.filename}')
|
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR)
|
||||||
|
self.directory_name = base_dir / self.KEYS_DIR
|
||||||
|
|
||||||
|
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
|
||||||
|
safe_name = base_name.lower().replace(':', '-').replace('/', '-')
|
||||||
|
|
||||||
|
self.filename = self.directory_name / f"{safe_name}.json"
|
||||||
|
|
||||||
|
logger.debug('JSON keystore: %s', self.filename)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_device(
|
def from_device(
|
||||||
@@ -293,7 +291,9 @@ class JsonKeyStore(KeyStore):
|
|||||||
|
|
||||||
return cls(namespace, filename)
|
return cls(namespace, filename)
|
||||||
|
|
||||||
async def load(self):
|
async def load(
|
||||||
|
self,
|
||||||
|
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
|
||||||
# Try to open the file, without failing. If the file does not exist, it
|
# Try to open the file, without failing. If the file does not exist, it
|
||||||
# will be created upon saving.
|
# will be created upon saving.
|
||||||
try:
|
try:
|
||||||
@@ -312,17 +312,17 @@ class JsonKeyStore(KeyStore):
|
|||||||
return next(iter(db.items()))
|
return next(iter(db.items()))
|
||||||
|
|
||||||
# Finally, just create an empty key map for the namespace
|
# Finally, just create an empty key map for the namespace
|
||||||
key_map = {}
|
key_map: dict[str, dict[str, Any]] = {}
|
||||||
db[self.namespace] = key_map
|
db[self.namespace] = key_map
|
||||||
return (db, key_map)
|
return (db, key_map)
|
||||||
|
|
||||||
async def save(self, db):
|
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None:
|
||||||
# Create the directory if it doesn't exist
|
# Create the directory if it doesn't exist
|
||||||
if not os.path.exists(self.directory_name):
|
if not self.directory_name.exists():
|
||||||
os.makedirs(self.directory_name, exist_ok=True)
|
self.directory_name.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Save to a temporary file
|
# Save to a temporary file
|
||||||
temp_filename = self.filename + '.tmp'
|
temp_filename = self.filename.with_name(self.filename.name + ".tmp")
|
||||||
with open(temp_filename, 'w', encoding='utf-8') as output:
|
with open(temp_filename, 'w', encoding='utf-8') as output:
|
||||||
json.dump(db, output, sort_keys=True, indent=4)
|
json.dump(db, output, sort_keys=True, indent=4)
|
||||||
|
|
||||||
@@ -334,16 +334,16 @@ class JsonKeyStore(KeyStore):
|
|||||||
del key_map[name]
|
del key_map[name]
|
||||||
await self.save(db)
|
await self.save(db)
|
||||||
|
|
||||||
async def update(self, name, keys):
|
async def update(self, name: str, keys: PairingKeys) -> None:
|
||||||
db, key_map = await self.load()
|
db, key_map = await self.load()
|
||||||
key_map.setdefault(name, {}).update(keys.to_dict())
|
key_map.setdefault(name, {}).update(keys.to_dict())
|
||||||
await self.save(db)
|
await self.save(db)
|
||||||
|
|
||||||
async def get_all(self):
|
async def get_all(self) -> list[tuple[str, PairingKeys]]:
|
||||||
_, key_map = await self.load()
|
_, key_map = await self.load()
|
||||||
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
|
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
|
||||||
|
|
||||||
async def delete_all(self):
|
async def delete_all(self) -> None:
|
||||||
db, key_map = await self.load()
|
db, key_map = await self.load()
|
||||||
key_map.clear()
|
key_map.clear()
|
||||||
await self.save(db)
|
await self.save(db)
|
||||||
|
|||||||
@@ -198,3 +198,24 @@ class CisTerminateInd(ControlPdu):
|
|||||||
cig_id: int
|
cig_id: int
|
||||||
cis_id: int
|
cis_id: int
|
||||||
error_code: int
|
error_code: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class FeatureReq(ControlPdu):
|
||||||
|
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
|
||||||
|
|
||||||
|
feature_set: bytes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class FeatureRsp(ControlPdu):
|
||||||
|
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
|
||||||
|
|
||||||
|
feature_set: bytes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PeripheralFeatureReq(ControlPdu):
|
||||||
|
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
|
||||||
|
|
||||||
|
feature_set: bytes
|
||||||
|
|||||||
@@ -322,3 +322,38 @@ class LmpNameRes(Packet):
|
|||||||
name_offset: int = field(metadata=hci.metadata(2))
|
name_offset: int = field(metadata=hci.metadata(2))
|
||||||
name_length: int = field(metadata=hci.metadata(3))
|
name_length: int = field(metadata=hci.metadata(3))
|
||||||
name_fregment: bytes = field(metadata=hci.metadata('*'))
|
name_fregment: bytes = field(metadata=hci.metadata('*'))
|
||||||
|
|
||||||
|
|
||||||
|
@Packet.subclass
|
||||||
|
@dataclass
|
||||||
|
class LmpFeaturesReq(Packet):
|
||||||
|
opcode = Opcode.LMP_FEATURES_REQ
|
||||||
|
|
||||||
|
features: bytes = field(metadata=hci.metadata(8))
|
||||||
|
|
||||||
|
|
||||||
|
@Packet.subclass
|
||||||
|
@dataclass
|
||||||
|
class LmpFeaturesRes(Packet):
|
||||||
|
opcode = Opcode.LMP_FEATURES_RES
|
||||||
|
|
||||||
|
features: bytes = field(metadata=hci.metadata(8))
|
||||||
|
|
||||||
|
|
||||||
|
@Packet.subclass
|
||||||
|
@dataclass
|
||||||
|
class LmpFeaturesReqExt(Packet):
|
||||||
|
opcode = Opcode.LMP_FEATURES_REQ_EXT
|
||||||
|
|
||||||
|
features_page: int = field(metadata=hci.metadata(1))
|
||||||
|
features: bytes = field(metadata=hci.metadata(8))
|
||||||
|
|
||||||
|
|
||||||
|
@Packet.subclass
|
||||||
|
@dataclass
|
||||||
|
class LmpFeaturesResExt(Packet):
|
||||||
|
opcode = Opcode.LMP_FEATURES_RES_EXT
|
||||||
|
|
||||||
|
features_page: int = field(metadata=hci.metadata(1))
|
||||||
|
max_features_page: int = field(metadata=hci.metadata(1))
|
||||||
|
features: bytes = field(metadata=hci.metadata(8))
|
||||||
|
|||||||
+10
-19
@@ -21,18 +21,9 @@ import enum
|
|||||||
import secrets
|
import secrets
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from bumble import hci
|
from bumble import hci, smp
|
||||||
from bumble.core import AdvertisingData, LeRole
|
from bumble.core import AdvertisingData, LeRole
|
||||||
from bumble.smp import (
|
from bumble.smp import (
|
||||||
SMP_DISPLAY_ONLY_IO_CAPABILITY,
|
|
||||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
|
|
||||||
SMP_ENC_KEY_DISTRIBUTION_FLAG,
|
|
||||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
|
||||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
|
|
||||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
|
|
||||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
|
||||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
|
||||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
|
||||||
OobContext,
|
OobContext,
|
||||||
OobLegacyContext,
|
OobLegacyContext,
|
||||||
OobSharedData,
|
OobSharedData,
|
||||||
@@ -96,11 +87,11 @@ class PairingDelegate:
|
|||||||
# These are defined abstractly, and can be mapped to specific Classic pairing
|
# These are defined abstractly, and can be mapped to specific Classic pairing
|
||||||
# and/or SMP constants.
|
# and/or SMP constants.
|
||||||
class IoCapability(enum.IntEnum):
|
class IoCapability(enum.IntEnum):
|
||||||
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT
|
||||||
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
|
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY
|
||||||
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
|
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY
|
||||||
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
|
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO
|
||||||
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
|
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY
|
||||||
|
|
||||||
# Direct names for backward compatibility.
|
# Direct names for backward compatibility.
|
||||||
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
|
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
|
||||||
@@ -111,10 +102,10 @@ class PairingDelegate:
|
|||||||
|
|
||||||
# Key Distribution [LE only]
|
# Key Distribution [LE only]
|
||||||
class KeyDistribution(enum.IntFlag):
|
class KeyDistribution(enum.IntFlag):
|
||||||
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
|
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY
|
||||||
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
|
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY
|
||||||
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
|
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY
|
||||||
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
|
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY
|
||||||
|
|
||||||
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
|
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
|
||||||
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
|
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
|
||||||
|
|||||||
+37
-39
@@ -664,46 +664,44 @@ class AudioStreamControlService(gatt.TemplateService):
|
|||||||
responses = []
|
responses = []
|
||||||
logger.debug(f'*** ASCS Write {operation} ***')
|
logger.debug(f'*** ASCS Write {operation} ***')
|
||||||
|
|
||||||
if isinstance(operation, ASE_Config_Codec):
|
match operation:
|
||||||
for ase_id, *args in zip(
|
case ASE_Config_Codec():
|
||||||
operation.ase_id,
|
for ase_id, *args in zip(
|
||||||
operation.target_latency,
|
operation.ase_id,
|
||||||
operation.target_phy,
|
operation.target_latency,
|
||||||
operation.codec_id,
|
operation.target_phy,
|
||||||
operation.codec_specific_configuration,
|
operation.codec_id,
|
||||||
|
operation.codec_specific_configuration,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
case ASE_Config_QOS():
|
||||||
|
for ase_id, *args in zip(
|
||||||
|
operation.ase_id,
|
||||||
|
operation.cig_id,
|
||||||
|
operation.cis_id,
|
||||||
|
operation.sdu_interval,
|
||||||
|
operation.framing,
|
||||||
|
operation.phy,
|
||||||
|
operation.max_sdu,
|
||||||
|
operation.retransmission_number,
|
||||||
|
operation.max_transport_latency,
|
||||||
|
operation.presentation_delay,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
case ASE_Enable() | ASE_Update_Metadata():
|
||||||
|
for ase_id, *args in zip(
|
||||||
|
operation.ase_id,
|
||||||
|
operation.metadata,
|
||||||
|
):
|
||||||
|
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||||
|
case (
|
||||||
|
ASE_Receiver_Start_Ready()
|
||||||
|
| ASE_Disable()
|
||||||
|
| ASE_Receiver_Stop_Ready()
|
||||||
|
| ASE_Release()
|
||||||
):
|
):
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
for ase_id in operation.ase_id:
|
||||||
elif isinstance(operation, ASE_Config_QOS):
|
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||||
for ase_id, *args in zip(
|
|
||||||
operation.ase_id,
|
|
||||||
operation.cig_id,
|
|
||||||
operation.cis_id,
|
|
||||||
operation.sdu_interval,
|
|
||||||
operation.framing,
|
|
||||||
operation.phy,
|
|
||||||
operation.max_sdu,
|
|
||||||
operation.retransmission_number,
|
|
||||||
operation.max_transport_latency,
|
|
||||||
operation.presentation_delay,
|
|
||||||
):
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
|
||||||
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
|
|
||||||
for ase_id, *args in zip(
|
|
||||||
operation.ase_id,
|
|
||||||
operation.metadata,
|
|
||||||
):
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
|
||||||
elif isinstance(
|
|
||||||
operation,
|
|
||||||
(
|
|
||||||
ASE_Receiver_Start_Ready,
|
|
||||||
ASE_Disable,
|
|
||||||
ASE_Receiver_Stop_Ready,
|
|
||||||
ASE_Release,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
for ase_id in operation.ase_id:
|
|
||||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
|
||||||
|
|
||||||
control_point_notification = bytes(
|
control_point_notification = bytes(
|
||||||
[operation.op_code, len(responses)]
|
[operation.op_code, len(responses)]
|
||||||
|
|||||||
+12
-11
@@ -333,17 +333,18 @@ class CodecSpecificCapabilities:
|
|||||||
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
||||||
offset += length - 1
|
offset += length - 1
|
||||||
|
|
||||||
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
|
match type:
|
||||||
supported_sampling_frequencies = SupportedSamplingFrequency(value)
|
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
|
||||||
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
|
supported_sampling_frequencies = SupportedSamplingFrequency(value)
|
||||||
supported_frame_durations = SupportedFrameDuration(value)
|
case CodecSpecificCapabilities.Type.FRAME_DURATION:
|
||||||
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
supported_frame_durations = SupportedFrameDuration(value)
|
||||||
supported_audio_channel_count = bits_to_channel_counts(value)
|
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
||||||
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
supported_audio_channel_count = bits_to_channel_counts(value)
|
||||||
min_octets_per_sample = value & 0xFFFF
|
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
||||||
max_octets_per_sample = value >> 16
|
min_octets_per_sample = value & 0xFFFF
|
||||||
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
|
max_octets_per_sample = value >> 16
|
||||||
supported_max_codec_frames_per_sdu = value
|
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
|
||||||
|
supported_max_codec_frames_per_sdu = value
|
||||||
|
|
||||||
# It is expected here that if some fields are missing, an error should be raised.
|
# It is expected here that if some fields are missing, an error should be raised.
|
||||||
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||||
|
|||||||
@@ -55,14 +55,15 @@ class GenericAccessService(TemplateService):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
|
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
|
||||||
):
|
):
|
||||||
if isinstance(appearance, int):
|
match appearance:
|
||||||
appearance_int = appearance
|
case int():
|
||||||
elif isinstance(appearance, tuple):
|
appearance_int = appearance
|
||||||
appearance_int = (appearance[0] << 6) | appearance[1]
|
case tuple():
|
||||||
elif isinstance(appearance, Appearance):
|
appearance_int = (appearance[0] << 6) | appearance[1]
|
||||||
appearance_int = int(appearance)
|
case Appearance():
|
||||||
else:
|
appearance_int = int(appearance)
|
||||||
raise TypeError()
|
case _:
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
self.device_name_characteristic = Characteristic(
|
self.device_name_characteristic = Characteristic(
|
||||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||||
|
|||||||
+524
-498
File diff suppressed because it is too large
Load Diff
+266
-265
@@ -31,14 +31,14 @@ from collections.abc import Awaitable, Callable, Sequence
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
|
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
|
||||||
|
|
||||||
from bumble import crypto, utils
|
from bumble import crypto, hci, utils
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.core import (
|
from bumble.core import (
|
||||||
AdvertisingData,
|
AdvertisingData,
|
||||||
InvalidArgumentError,
|
InvalidArgumentError,
|
||||||
|
InvalidPacketError,
|
||||||
PhysicalTransport,
|
PhysicalTransport,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
name_or_number,
|
|
||||||
)
|
)
|
||||||
from bumble.hci import (
|
from bumble.hci import (
|
||||||
Address,
|
Address,
|
||||||
@@ -46,7 +46,6 @@ from bumble.hci import (
|
|||||||
HCI_LE_Enable_Encryption_Command,
|
HCI_LE_Enable_Encryption_Command,
|
||||||
HCI_Object,
|
HCI_Object,
|
||||||
Role,
|
Role,
|
||||||
key_with_value,
|
|
||||||
metadata,
|
metadata,
|
||||||
)
|
)
|
||||||
from bumble.keys import PairingKeys
|
from bumble.keys import PairingKeys
|
||||||
@@ -71,115 +70,125 @@ logger = logging.getLogger(__name__)
|
|||||||
SMP_CID = 0x06
|
SMP_CID = 0x06
|
||||||
SMP_BR_CID = 0x07
|
SMP_BR_CID = 0x07
|
||||||
|
|
||||||
SMP_PAIRING_REQUEST_COMMAND = 0x01
|
class CommandCode(hci.SpecableEnum):
|
||||||
SMP_PAIRING_RESPONSE_COMMAND = 0x02
|
PAIRING_REQUEST = 0x01
|
||||||
SMP_PAIRING_CONFIRM_COMMAND = 0x03
|
PAIRING_RESPONSE = 0x02
|
||||||
SMP_PAIRING_RANDOM_COMMAND = 0x04
|
PAIRING_CONFIRM = 0x03
|
||||||
SMP_PAIRING_FAILED_COMMAND = 0x05
|
PAIRING_RANDOM = 0x04
|
||||||
SMP_ENCRYPTION_INFORMATION_COMMAND = 0x06
|
PAIRING_FAILED = 0x05
|
||||||
SMP_MASTER_IDENTIFICATION_COMMAND = 0x07
|
ENCRYPTION_INFORMATION = 0x06
|
||||||
SMP_IDENTITY_INFORMATION_COMMAND = 0x08
|
MASTER_IDENTIFICATION = 0x07
|
||||||
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND = 0x09
|
IDENTITY_INFORMATION = 0x08
|
||||||
SMP_SIGNING_INFORMATION_COMMAND = 0x0A
|
IDENTITY_ADDRESS_INFORMATION = 0x09
|
||||||
SMP_SECURITY_REQUEST_COMMAND = 0x0B
|
SIGNING_INFORMATION = 0x0A
|
||||||
SMP_PAIRING_PUBLIC_KEY_COMMAND = 0x0C
|
SECURITY_REQUEST = 0x0B
|
||||||
SMP_PAIRING_DHKEY_CHECK_COMMAND = 0x0D
|
PAIRING_PUBLIC_KEY = 0x0C
|
||||||
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND = 0x0E
|
PAIRING_DHKEY_CHECK = 0x0D
|
||||||
|
PAIRING_KEYPRESS_NOTIFICATION = 0x0E
|
||||||
|
|
||||||
SMP_COMMAND_NAMES = {
|
|
||||||
SMP_PAIRING_REQUEST_COMMAND: 'SMP_PAIRING_REQUEST_COMMAND',
|
|
||||||
SMP_PAIRING_RESPONSE_COMMAND: 'SMP_PAIRING_RESPONSE_COMMAND',
|
|
||||||
SMP_PAIRING_CONFIRM_COMMAND: 'SMP_PAIRING_CONFIRM_COMMAND',
|
|
||||||
SMP_PAIRING_RANDOM_COMMAND: 'SMP_PAIRING_RANDOM_COMMAND',
|
|
||||||
SMP_PAIRING_FAILED_COMMAND: 'SMP_PAIRING_FAILED_COMMAND',
|
|
||||||
SMP_ENCRYPTION_INFORMATION_COMMAND: 'SMP_ENCRYPTION_INFORMATION_COMMAND',
|
|
||||||
SMP_MASTER_IDENTIFICATION_COMMAND: 'SMP_MASTER_IDENTIFICATION_COMMAND',
|
|
||||||
SMP_IDENTITY_INFORMATION_COMMAND: 'SMP_IDENTITY_INFORMATION_COMMAND',
|
|
||||||
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND: 'SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND',
|
|
||||||
SMP_SIGNING_INFORMATION_COMMAND: 'SMP_SIGNING_INFORMATION_COMMAND',
|
|
||||||
SMP_SECURITY_REQUEST_COMMAND: 'SMP_SECURITY_REQUEST_COMMAND',
|
|
||||||
SMP_PAIRING_PUBLIC_KEY_COMMAND: 'SMP_PAIRING_PUBLIC_KEY_COMMAND',
|
|
||||||
SMP_PAIRING_DHKEY_CHECK_COMMAND: 'SMP_PAIRING_DHKEY_CHECK_COMMAND',
|
|
||||||
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND: 'SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND'
|
|
||||||
}
|
|
||||||
|
|
||||||
SMP_DISPLAY_ONLY_IO_CAPABILITY = 0x00
|
class IoCapability(hci.SpecableEnum):
|
||||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY = 0x01
|
DISPLAY_ONLY = 0x00
|
||||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY = 0x02
|
DISPLAY_YES_NO = 0x01
|
||||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = 0x03
|
KEYBOARD_ONLY = 0x02
|
||||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = 0x04
|
NO_INPUT_NO_OUTPUT = 0x03
|
||||||
|
KEYBOARD_DISPLAY = 0x04
|
||||||
|
|
||||||
SMP_IO_CAPABILITY_NAMES = {
|
SMP_DISPLAY_ONLY_IO_CAPABILITY = IoCapability.DISPLAY_ONLY
|
||||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: 'SMP_DISPLAY_ONLY_IO_CAPABILITY',
|
SMP_DISPLAY_YES_NO_IO_CAPABILITY = IoCapability.DISPLAY_YES_NO
|
||||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: 'SMP_DISPLAY_YES_NO_IO_CAPABILITY',
|
SMP_KEYBOARD_ONLY_IO_CAPABILITY = IoCapability.KEYBOARD_ONLY
|
||||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: 'SMP_KEYBOARD_ONLY_IO_CAPABILITY',
|
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = IoCapability.NO_INPUT_NO_OUTPUT
|
||||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: 'SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY',
|
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = IoCapability.KEYBOARD_DISPLAY
|
||||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: 'SMP_KEYBOARD_DISPLAY_IO_CAPABILITY'
|
|
||||||
}
|
|
||||||
|
|
||||||
SMP_PASSKEY_ENTRY_FAILED_ERROR = 0x01
|
class ErrorCode(hci.SpecableEnum):
|
||||||
SMP_OOB_NOT_AVAILABLE_ERROR = 0x02
|
PASSKEY_ENTRY_FAILED = 0x01
|
||||||
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = 0x03
|
OOB_NOT_AVAILABLE = 0x02
|
||||||
SMP_CONFIRM_VALUE_FAILED_ERROR = 0x04
|
AUTHENTICATION_REQUIREMENTS = 0x03
|
||||||
SMP_PAIRING_NOT_SUPPORTED_ERROR = 0x05
|
CONFIRM_VALUE_FAILED = 0x04
|
||||||
SMP_ENCRYPTION_KEY_SIZE_ERROR = 0x06
|
PAIRING_NOT_SUPPORTED = 0x05
|
||||||
SMP_COMMAND_NOT_SUPPORTED_ERROR = 0x07
|
ENCRYPTION_KEY_SIZE = 0x06
|
||||||
SMP_UNSPECIFIED_REASON_ERROR = 0x08
|
COMMAND_NOT_SUPPORTED = 0x07
|
||||||
SMP_REPEATED_ATTEMPTS_ERROR = 0x09
|
UNSPECIFIED_REASON = 0x08
|
||||||
SMP_INVALID_PARAMETERS_ERROR = 0x0A
|
REPEATED_ATTEMPTS = 0x09
|
||||||
SMP_DHKEY_CHECK_FAILED_ERROR = 0x0B
|
INVALID_PARAMETERS = 0x0A
|
||||||
SMP_NUMERIC_COMPARISON_FAILED_ERROR = 0x0C
|
DHKEY_CHECK_FAILED = 0x0B
|
||||||
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = 0x0D
|
NUMERIC_COMPARISON_FAILED = 0x0C
|
||||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = 0x0E
|
BD_EDR_PAIRING_IN_PROGRESS = 0x0D
|
||||||
|
CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED = 0x0E
|
||||||
|
|
||||||
SMP_ERROR_NAMES = {
|
SMP_PASSKEY_ENTRY_FAILED_ERROR = ErrorCode.PASSKEY_ENTRY_FAILED
|
||||||
SMP_PASSKEY_ENTRY_FAILED_ERROR: 'SMP_PASSKEY_ENTRY_FAILED_ERROR',
|
SMP_OOB_NOT_AVAILABLE_ERROR = ErrorCode.OOB_NOT_AVAILABLE
|
||||||
SMP_OOB_NOT_AVAILABLE_ERROR: 'SMP_OOB_NOT_AVAILABLE_ERROR',
|
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = ErrorCode.AUTHENTICATION_REQUIREMENTS
|
||||||
SMP_AUTHENTICATION_REQUIREMENTS_ERROR: 'SMP_AUTHENTICATION_REQUIREMENTS_ERROR',
|
SMP_CONFIRM_VALUE_FAILED_ERROR = ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
SMP_CONFIRM_VALUE_FAILED_ERROR: 'SMP_CONFIRM_VALUE_FAILED_ERROR',
|
SMP_PAIRING_NOT_SUPPORTED_ERROR = ErrorCode.PAIRING_NOT_SUPPORTED
|
||||||
SMP_PAIRING_NOT_SUPPORTED_ERROR: 'SMP_PAIRING_NOT_SUPPORTED_ERROR',
|
SMP_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.ENCRYPTION_KEY_SIZE
|
||||||
SMP_ENCRYPTION_KEY_SIZE_ERROR: 'SMP_ENCRYPTION_KEY_SIZE_ERROR',
|
SMP_COMMAND_NOT_SUPPORTED_ERROR = ErrorCode.COMMAND_NOT_SUPPORTED
|
||||||
SMP_COMMAND_NOT_SUPPORTED_ERROR: 'SMP_COMMAND_NOT_SUPPORTED_ERROR',
|
SMP_UNSPECIFIED_REASON_ERROR = ErrorCode.UNSPECIFIED_REASON
|
||||||
SMP_UNSPECIFIED_REASON_ERROR: 'SMP_UNSPECIFIED_REASON_ERROR',
|
SMP_REPEATED_ATTEMPTS_ERROR = ErrorCode.REPEATED_ATTEMPTS
|
||||||
SMP_REPEATED_ATTEMPTS_ERROR: 'SMP_REPEATED_ATTEMPTS_ERROR',
|
SMP_INVALID_PARAMETERS_ERROR = ErrorCode.INVALID_PARAMETERS
|
||||||
SMP_INVALID_PARAMETERS_ERROR: 'SMP_INVALID_PARAMETERS_ERROR',
|
SMP_DHKEY_CHECK_FAILED_ERROR = ErrorCode.DHKEY_CHECK_FAILED
|
||||||
SMP_DHKEY_CHECK_FAILED_ERROR: 'SMP_DHKEY_CHECK_FAILED_ERROR',
|
SMP_NUMERIC_COMPARISON_FAILED_ERROR = ErrorCode.NUMERIC_COMPARISON_FAILED
|
||||||
SMP_NUMERIC_COMPARISON_FAILED_ERROR: 'SMP_NUMERIC_COMPARISON_FAILED_ERROR',
|
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = ErrorCode.BD_EDR_PAIRING_IN_PROGRESS
|
||||||
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR: 'SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR',
|
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
|
||||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR: 'SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR'
|
|
||||||
}
|
|
||||||
|
|
||||||
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE = 0
|
class KeypressNotificationType(hci.SpecableEnum):
|
||||||
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE = 1
|
PASSKEY_ENTRY_STARTED = 0
|
||||||
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE = 2
|
PASSKEY_DIGIT_ENTERED = 1
|
||||||
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE = 3
|
PASSKEY_DIGIT_ERASED = 2
|
||||||
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE = 4
|
PASSKEY_CLEARED = 3
|
||||||
|
PASSKEY_ENTRY_COMPLETED = 4
|
||||||
SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES = {
|
|
||||||
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE',
|
|
||||||
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE',
|
|
||||||
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE',
|
|
||||||
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE',
|
|
||||||
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Bit flags for key distribution/generation
|
# Bit flags for key distribution/generation
|
||||||
SMP_ENC_KEY_DISTRIBUTION_FLAG = 0b0001
|
class KeyDistribution(hci.SpecableFlag):
|
||||||
SMP_ID_KEY_DISTRIBUTION_FLAG = 0b0010
|
ENC_KEY = 0b0001
|
||||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG = 0b0100
|
ID_KEY = 0b0010
|
||||||
SMP_LINK_KEY_DISTRIBUTION_FLAG = 0b1000
|
SIGN_KEY = 0b0100
|
||||||
|
LINK_KEY = 0b1000
|
||||||
|
|
||||||
# AuthReq fields
|
# AuthReq fields
|
||||||
SMP_BONDING_AUTHREQ = 0b00000001
|
class AuthReq(hci.SpecableFlag):
|
||||||
SMP_MITM_AUTHREQ = 0b00000100
|
BONDING = 0b00000001
|
||||||
SMP_SC_AUTHREQ = 0b00001000
|
MITM = 0b00000100
|
||||||
SMP_KEYPRESS_AUTHREQ = 0b00010000
|
SC = 0b00001000
|
||||||
SMP_CT2_AUTHREQ = 0b00100000
|
KEYPRESS = 0b00010000
|
||||||
|
CT2 = 0b00100000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_booleans(
|
||||||
|
cls,
|
||||||
|
bonding: bool = False,
|
||||||
|
sc: bool = False,
|
||||||
|
mitm: bool = False,
|
||||||
|
keypress: bool = False,
|
||||||
|
ct2: bool = False,
|
||||||
|
) -> AuthReq:
|
||||||
|
auth_req = AuthReq(0)
|
||||||
|
if bonding:
|
||||||
|
auth_req |= AuthReq.BONDING
|
||||||
|
if sc:
|
||||||
|
auth_req |= AuthReq.SC
|
||||||
|
if mitm:
|
||||||
|
auth_req |= AuthReq.MITM
|
||||||
|
if keypress:
|
||||||
|
auth_req |= AuthReq.KEYPRESS
|
||||||
|
if ct2:
|
||||||
|
auth_req |= AuthReq.CT2
|
||||||
|
return auth_req
|
||||||
|
|
||||||
# Crypto salt
|
# Crypto salt
|
||||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
||||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
||||||
|
|
||||||
|
# Diffie-Hellman private / public key pair in Debug Mode (Core - Vol. 3, Part H)
|
||||||
|
SMP_DEBUG_KEY_PRIVATE = bytes.fromhex(
|
||||||
|
'3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd'
|
||||||
|
)
|
||||||
|
SMP_DEBUG_KEY_PUBLIC_X = bytes.fromhex(
|
||||||
|
'20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6'
|
||||||
|
)
|
||||||
|
SMP_DEBUG_KEY_PUBLIC_Y= bytes.fromhex(
|
||||||
|
'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b'
|
||||||
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
@@ -188,8 +197,6 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Utils
|
# Utils
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def error_name(error_code: int) -> str:
|
|
||||||
return name_or_number(SMP_ERROR_NAMES, error_code)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -201,20 +208,22 @@ class SMP_Command:
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
|
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
|
||||||
'''
|
'''
|
||||||
|
|
||||||
smp_classes: ClassVar[dict[int, type[SMP_Command]]] = {}
|
smp_classes: ClassVar[dict[CommandCode, type[SMP_Command]]] = {}
|
||||||
fields: ClassVar[Fields]
|
fields: ClassVar[Fields]
|
||||||
code: int = field(default=0, init=False)
|
code: CommandCode = field(default=CommandCode(0), init=False)
|
||||||
name: str = field(default='', init=False)
|
name: str = field(default='', init=False)
|
||||||
_payload: bytes | None = field(default=None, init=False)
|
_payload: bytes | None = field(default=None, init=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, pdu: bytes) -> SMP_Command:
|
def from_bytes(cls, pdu: bytes) -> SMP_Command:
|
||||||
code = pdu[0]
|
if not pdu:
|
||||||
|
raise InvalidPacketError("Empty SMP PDU")
|
||||||
|
code = CommandCode(pdu[0])
|
||||||
|
|
||||||
subclass = SMP_Command.smp_classes.get(code)
|
subclass = SMP_Command.smp_classes.get(code)
|
||||||
if subclass is None:
|
if subclass is None:
|
||||||
instance = SMP_Command()
|
instance = SMP_Command()
|
||||||
instance.name = SMP_Command.command_name(code)
|
instance.name = code.name
|
||||||
instance.code = code
|
instance.code = code
|
||||||
instance.payload = pdu
|
instance.payload = pdu
|
||||||
return instance
|
return instance
|
||||||
@@ -222,59 +231,14 @@ class SMP_Command:
|
|||||||
instance.payload = pdu[1:]
|
instance.payload = pdu[1:]
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def command_name(code: int) -> str:
|
|
||||||
return name_or_number(SMP_COMMAND_NAMES, code)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def auth_req_str(value: int) -> str:
|
|
||||||
bonding_flags = value & 3
|
|
||||||
mitm = (value >> 2) & 1
|
|
||||||
sc = (value >> 3) & 1
|
|
||||||
keypress = (value >> 4) & 1
|
|
||||||
ct2 = (value >> 5) & 1
|
|
||||||
|
|
||||||
return (
|
|
||||||
f'bonding_flags={bonding_flags}, '
|
|
||||||
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def io_capability_name(io_capability: int) -> str:
|
|
||||||
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def key_distribution_str(value: int) -> str:
|
|
||||||
key_types: list[str] = []
|
|
||||||
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
|
||||||
key_types.append('ENC')
|
|
||||||
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
|
||||||
key_types.append('ID')
|
|
||||||
if value & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
|
||||||
key_types.append('SIGN')
|
|
||||||
if value & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
|
||||||
key_types.append('LINK')
|
|
||||||
return ','.join(key_types)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def keypress_notification_type_name(notification_type: int) -> str:
|
|
||||||
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
|
|
||||||
|
|
||||||
_Command = TypeVar("_Command", bound="SMP_Command")
|
_Command = TypeVar("_Command", bound="SMP_Command")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
|
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
|
||||||
subclass.name = subclass.__name__.upper()
|
|
||||||
subclass.code = key_with_value(SMP_COMMAND_NAMES, subclass.name)
|
|
||||||
if subclass.code is None:
|
|
||||||
raise KeyError(
|
|
||||||
f'Command name {subclass.name} not found in SMP_COMMAND_NAMES'
|
|
||||||
)
|
|
||||||
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
|
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
|
||||||
|
subclass.name = subclass.__name__.upper()
|
||||||
# Register a factory for this class
|
# Register a factory for this class
|
||||||
SMP_Command.smp_classes[subclass.code] = subclass
|
SMP_Command.smp_classes[subclass.code] = subclass
|
||||||
|
|
||||||
return subclass
|
return subclass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -308,19 +272,17 @@ class SMP_Pairing_Request_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
|
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
|
||||||
'''
|
'''
|
||||||
|
|
||||||
io_capability: int = field(
|
code = CommandCode.PAIRING_REQUEST
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
|
|
||||||
)
|
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
|
||||||
oob_data_flag: int = field(metadata=metadata(1))
|
oob_data_flag: int = field(metadata=metadata(1))
|
||||||
auth_req: int = field(
|
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
|
||||||
)
|
|
||||||
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
||||||
initiator_key_distribution: int = field(
|
initiator_key_distribution: KeyDistribution = field(
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
metadata=KeyDistribution.type_metadata(1)
|
||||||
)
|
)
|
||||||
responder_key_distribution: int = field(
|
responder_key_distribution: KeyDistribution = field(
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
metadata=KeyDistribution.type_metadata(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -332,19 +294,17 @@ class SMP_Pairing_Response_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
|
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
|
||||||
'''
|
'''
|
||||||
|
|
||||||
io_capability: int = field(
|
code = CommandCode.PAIRING_RESPONSE
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
|
|
||||||
)
|
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
|
||||||
oob_data_flag: int = field(metadata=metadata(1))
|
oob_data_flag: int = field(metadata=metadata(1))
|
||||||
auth_req: int = field(
|
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
|
||||||
)
|
|
||||||
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
||||||
initiator_key_distribution: int = field(
|
initiator_key_distribution: KeyDistribution = field(
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
metadata=KeyDistribution.type_metadata(1)
|
||||||
)
|
)
|
||||||
responder_key_distribution: int = field(
|
responder_key_distribution: KeyDistribution = field(
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
metadata=KeyDistribution.type_metadata(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -356,6 +316,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
|
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.PAIRING_CONFIRM
|
||||||
|
|
||||||
confirm_value: bytes = field(metadata=metadata(16))
|
confirm_value: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -367,6 +329,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
|
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.PAIRING_RANDOM
|
||||||
|
|
||||||
random_value: bytes = field(metadata=metadata(16))
|
random_value: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -378,7 +342,9 @@ class SMP_Pairing_Failed_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
|
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
|
||||||
'''
|
'''
|
||||||
|
|
||||||
reason: int = field(metadata=metadata({'size': 1, 'mapper': error_name}))
|
code = CommandCode.PAIRING_FAILED
|
||||||
|
|
||||||
|
reason: ErrorCode = field(metadata=ErrorCode.type_metadata(1))
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -389,6 +355,8 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
|
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.PAIRING_PUBLIC_KEY
|
||||||
|
|
||||||
public_key_x: bytes = field(metadata=metadata(32))
|
public_key_x: bytes = field(metadata=metadata(32))
|
||||||
public_key_y: bytes = field(metadata=metadata(32))
|
public_key_y: bytes = field(metadata=metadata(32))
|
||||||
|
|
||||||
@@ -401,6 +369,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
|
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.PAIRING_DHKEY_CHECK
|
||||||
|
|
||||||
dhkey_check: bytes = field(metadata=metadata(16))
|
dhkey_check: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -412,10 +382,10 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
|
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
|
||||||
'''
|
'''
|
||||||
|
|
||||||
notification_type: int = field(
|
code = CommandCode.PAIRING_KEYPRESS_NOTIFICATION
|
||||||
metadata=metadata(
|
|
||||||
{'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}
|
notification_type: KeypressNotificationType = field(
|
||||||
)
|
metadata=KeypressNotificationType.type_metadata(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -427,6 +397,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
|
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.ENCRYPTION_INFORMATION
|
||||||
|
|
||||||
long_term_key: bytes = field(metadata=metadata(16))
|
long_term_key: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -438,6 +410,8 @@ class SMP_Master_Identification_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
|
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.MASTER_IDENTIFICATION
|
||||||
|
|
||||||
ediv: int = field(metadata=metadata(2))
|
ediv: int = field(metadata=metadata(2))
|
||||||
rand: bytes = field(metadata=metadata(8))
|
rand: bytes = field(metadata=metadata(8))
|
||||||
|
|
||||||
@@ -450,6 +424,8 @@ class SMP_Identity_Information_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
|
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.IDENTITY_INFORMATION
|
||||||
|
|
||||||
identity_resolving_key: bytes = field(metadata=metadata(16))
|
identity_resolving_key: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -461,6 +437,8 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
|
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.IDENTITY_ADDRESS_INFORMATION
|
||||||
|
|
||||||
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
|
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
|
||||||
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
|
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
|
||||||
|
|
||||||
@@ -473,6 +451,8 @@ class SMP_Signing_Information_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
|
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
code = CommandCode.SIGNING_INFORMATION
|
||||||
|
|
||||||
signature_key: bytes = field(metadata=metadata(16))
|
signature_key: bytes = field(metadata=metadata(16))
|
||||||
|
|
||||||
|
|
||||||
@@ -484,25 +464,9 @@ class SMP_Security_Request_Command(SMP_Command):
|
|||||||
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
|
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
|
||||||
'''
|
'''
|
||||||
|
|
||||||
auth_req: int = field(
|
code = CommandCode.SECURITY_REQUEST
|
||||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
|
||||||
)
|
|
||||||
|
|
||||||
|
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
|
|
||||||
value = 0
|
|
||||||
if bonding:
|
|
||||||
value |= SMP_BONDING_AUTHREQ
|
|
||||||
if mitm:
|
|
||||||
value |= SMP_MITM_AUTHREQ
|
|
||||||
if sc:
|
|
||||||
value |= SMP_SC_AUTHREQ
|
|
||||||
if keypress:
|
|
||||||
value |= SMP_KEYPRESS_AUTHREQ
|
|
||||||
if ct2:
|
|
||||||
value |= SMP_CT2_AUTHREQ
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -676,8 +640,8 @@ class Session:
|
|||||||
self.ltk_rand = bytes(8)
|
self.ltk_rand = bytes(8)
|
||||||
self.link_key: bytes | None = None
|
self.link_key: bytes | None = None
|
||||||
self.maximum_encryption_key_size: int = 0
|
self.maximum_encryption_key_size: int = 0
|
||||||
self.initiator_key_distribution: int = 0
|
self.initiator_key_distribution: KeyDistribution = KeyDistribution(0)
|
||||||
self.responder_key_distribution: int = 0
|
self.responder_key_distribution: KeyDistribution = KeyDistribution(0)
|
||||||
self.peer_random_value: bytes | None = None
|
self.peer_random_value: bytes | None = None
|
||||||
self.peer_public_key_x: bytes = bytes(32)
|
self.peer_public_key_x: bytes = bytes(32)
|
||||||
self.peer_public_key_y = bytes(32)
|
self.peer_public_key_y = bytes(32)
|
||||||
@@ -728,10 +692,10 @@ class Session:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Key Distribution (default values before negotiation)
|
# Key Distribution (default values before negotiation)
|
||||||
self.initiator_key_distribution = (
|
self.initiator_key_distribution = KeyDistribution(
|
||||||
pairing_config.delegate.local_initiator_key_distribution
|
pairing_config.delegate.local_initiator_key_distribution
|
||||||
)
|
)
|
||||||
self.responder_key_distribution = (
|
self.responder_key_distribution = KeyDistribution(
|
||||||
pairing_config.delegate.local_responder_key_distribution
|
pairing_config.delegate.local_responder_key_distribution
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -743,7 +707,7 @@ class Session:
|
|||||||
self.ct2: bool = False
|
self.ct2: bool = False
|
||||||
|
|
||||||
# I/O Capabilities
|
# I/O Capabilities
|
||||||
self.io_capability = pairing_config.delegate.io_capability
|
self.io_capability = IoCapability(pairing_config.delegate.io_capability)
|
||||||
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||||
|
|
||||||
# OOB
|
# OOB
|
||||||
@@ -822,8 +786,14 @@ class Session:
|
|||||||
return self.nx[0 if self.is_responder else 1]
|
return self.nx[0 if self.is_responder else 1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_req(self) -> int:
|
def auth_req(self) -> AuthReq:
|
||||||
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
|
return AuthReq.from_booleans(
|
||||||
|
bonding=self.bonding,
|
||||||
|
sc=self.sc,
|
||||||
|
mitm=self.mitm,
|
||||||
|
keypress=self.keypress,
|
||||||
|
ct2=self.ct2,
|
||||||
|
)
|
||||||
|
|
||||||
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
|
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
|
||||||
if not self.sc and not self.completed:
|
if not self.sc and not self.completed:
|
||||||
@@ -843,7 +813,7 @@ class Session:
|
|||||||
if self.connection.transport == PhysicalTransport.BR_EDR:
|
if self.connection.transport == PhysicalTransport.BR_EDR:
|
||||||
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
||||||
return
|
return
|
||||||
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
|
if (not self.mitm) and (auth_req & AuthReq.MITM == 0):
|
||||||
self.pairing_method = PairingMethod.JUST_WORKS
|
self.pairing_method = PairingMethod.JUST_WORKS
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -861,7 +831,7 @@ class Session:
|
|||||||
self.passkey_display = details[1 if self.is_initiator else 2]
|
self.passkey_display = details[1 if self.is_initiator else 2]
|
||||||
|
|
||||||
def check_expected_value(
|
def check_expected_value(
|
||||||
self, expected: bytes, received: bytes, error: int
|
self, expected: bytes, received: bytes, error: ErrorCode
|
||||||
) -> bool:
|
) -> bool:
|
||||||
logger.debug(f'expected={expected.hex()} got={received.hex()}')
|
logger.debug(f'expected={expected.hex()} got={received.hex()}')
|
||||||
if expected != received:
|
if expected != received:
|
||||||
@@ -881,7 +851,7 @@ class Session:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('exception while confirm')
|
logger.exception('exception while confirm')
|
||||||
|
|
||||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
|
||||||
|
|
||||||
self.connection.cancel_on_disconnection(prompt())
|
self.connection.cancel_on_disconnection(prompt())
|
||||||
|
|
||||||
@@ -900,7 +870,7 @@ class Session:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('exception while prompting')
|
logger.exception('exception while prompting')
|
||||||
|
|
||||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
|
||||||
|
|
||||||
self.connection.cancel_on_disconnection(prompt())
|
self.connection.cancel_on_disconnection(prompt())
|
||||||
|
|
||||||
@@ -911,13 +881,13 @@ class Session:
|
|||||||
passkey = await self.pairing_config.delegate.get_number()
|
passkey = await self.pairing_config.delegate.get_number()
|
||||||
if passkey is None:
|
if passkey is None:
|
||||||
logger.debug('Passkey request rejected')
|
logger.debug('Passkey request rejected')
|
||||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
|
||||||
return
|
return
|
||||||
logger.debug(f'user input: {passkey}')
|
logger.debug(f'user input: {passkey}')
|
||||||
next_steps(passkey)
|
next_steps(passkey)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('exception while prompting')
|
logger.exception('exception while prompting')
|
||||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
|
||||||
|
|
||||||
self.connection.cancel_on_disconnection(prompt())
|
self.connection.cancel_on_disconnection(prompt())
|
||||||
|
|
||||||
@@ -972,7 +942,7 @@ class Session:
|
|||||||
def send_command(self, command: SMP_Command) -> None:
|
def send_command(self, command: SMP_Command) -> None:
|
||||||
self.manager.send_command(self.connection, command)
|
self.manager.send_command(self.connection, command)
|
||||||
|
|
||||||
def send_pairing_failed(self, error: int) -> None:
|
def send_pairing_failed(self, error: ErrorCode) -> None:
|
||||||
self.send_command(SMP_Pairing_Failed_Command(reason=error))
|
self.send_command(SMP_Pairing_Failed_Command(reason=error))
|
||||||
self.on_pairing_failure(error)
|
self.on_pairing_failure(error)
|
||||||
|
|
||||||
@@ -1144,7 +1114,7 @@ class Session:
|
|||||||
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
||||||
)
|
)
|
||||||
self.send_pairing_failed(
|
self.send_pairing_failed(
|
||||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
|
ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
||||||
@@ -1155,14 +1125,14 @@ class Session:
|
|||||||
# CTKD: Derive LTK from LinkKey
|
# CTKD: Derive LTK from LinkKey
|
||||||
if (
|
if (
|
||||||
self.connection.transport == PhysicalTransport.BR_EDR
|
self.connection.transport == PhysicalTransport.BR_EDR
|
||||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
and self.initiator_key_distribution & KeyDistribution.ENC_KEY
|
||||||
):
|
):
|
||||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||||
self.get_link_key_and_derive_ltk()
|
self.get_link_key_and_derive_ltk()
|
||||||
)
|
)
|
||||||
elif not self.sc:
|
elif not self.sc:
|
||||||
# Distribute the LTK, EDIV and RAND
|
# Distribute the LTK, EDIV and RAND
|
||||||
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
if self.initiator_key_distribution & KeyDistribution.ENC_KEY:
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
||||||
)
|
)
|
||||||
@@ -1173,7 +1143,7 @@ class Session:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Distribute IRK & BD ADDR
|
# Distribute IRK & BD ADDR
|
||||||
if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
if self.initiator_key_distribution & KeyDistribution.ID_KEY:
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Identity_Information_Command(
|
SMP_Identity_Information_Command(
|
||||||
identity_resolving_key=self.manager.device.irk
|
identity_resolving_key=self.manager.device.irk
|
||||||
@@ -1183,25 +1153,25 @@ class Session:
|
|||||||
|
|
||||||
# Distribute CSRK
|
# Distribute CSRK
|
||||||
csrk = bytes(16) # FIXME: testing
|
csrk = bytes(16) # FIXME: testing
|
||||||
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
if self.initiator_key_distribution & KeyDistribution.SIGN_KEY:
|
||||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||||
|
|
||||||
# CTKD, calculate BR/EDR link key
|
# CTKD, calculate BR/EDR link key
|
||||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
if self.initiator_key_distribution & KeyDistribution.LINK_KEY:
|
||||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# CTKD: Derive LTK from LinkKey
|
# CTKD: Derive LTK from LinkKey
|
||||||
if (
|
if (
|
||||||
self.connection.transport == PhysicalTransport.BR_EDR
|
self.connection.transport == PhysicalTransport.BR_EDR
|
||||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
and self.responder_key_distribution & KeyDistribution.ENC_KEY
|
||||||
):
|
):
|
||||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||||
self.get_link_key_and_derive_ltk()
|
self.get_link_key_and_derive_ltk()
|
||||||
)
|
)
|
||||||
# Distribute the LTK, EDIV and RAND
|
# Distribute the LTK, EDIV and RAND
|
||||||
elif not self.sc:
|
elif not self.sc:
|
||||||
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
if self.responder_key_distribution & KeyDistribution.ENC_KEY:
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
||||||
)
|
)
|
||||||
@@ -1212,7 +1182,7 @@ class Session:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Distribute IRK & BD ADDR
|
# Distribute IRK & BD ADDR
|
||||||
if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
if self.responder_key_distribution & KeyDistribution.ID_KEY:
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Identity_Information_Command(
|
SMP_Identity_Information_Command(
|
||||||
identity_resolving_key=self.manager.device.irk
|
identity_resolving_key=self.manager.device.irk
|
||||||
@@ -1222,30 +1192,30 @@ class Session:
|
|||||||
|
|
||||||
# Distribute CSRK
|
# Distribute CSRK
|
||||||
csrk = bytes(16) # FIXME: testing
|
csrk = bytes(16) # FIXME: testing
|
||||||
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
if self.responder_key_distribution & KeyDistribution.SIGN_KEY:
|
||||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||||
|
|
||||||
# CTKD, calculate BR/EDR link key
|
# CTKD, calculate BR/EDR link key
|
||||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
if self.responder_key_distribution & KeyDistribution.LINK_KEY:
|
||||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||||
|
|
||||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||||
# Set our expectations for what to wait for in the key distribution phase
|
# Set our expectations for what to wait for in the key distribution phase
|
||||||
self.peer_expected_distributions = []
|
self.peer_expected_distributions = []
|
||||||
if not self.sc and self.connection.transport == PhysicalTransport.LE:
|
if not self.sc and self.connection.transport == PhysicalTransport.LE:
|
||||||
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
|
if key_distribution_flags & KeyDistribution.ENC_KEY != 0:
|
||||||
self.peer_expected_distributions.append(
|
self.peer_expected_distributions.append(
|
||||||
SMP_Encryption_Information_Command
|
SMP_Encryption_Information_Command
|
||||||
)
|
)
|
||||||
self.peer_expected_distributions.append(
|
self.peer_expected_distributions.append(
|
||||||
SMP_Master_Identification_Command
|
SMP_Master_Identification_Command
|
||||||
)
|
)
|
||||||
if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0:
|
if key_distribution_flags & KeyDistribution.ID_KEY != 0:
|
||||||
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
|
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
|
||||||
self.peer_expected_distributions.append(
|
self.peer_expected_distributions.append(
|
||||||
SMP_Identity_Address_Information_Command
|
SMP_Identity_Address_Information_Command
|
||||||
)
|
)
|
||||||
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
|
if key_distribution_flags & KeyDistribution.SIGN_KEY != 0:
|
||||||
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
|
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
'expecting distributions: '
|
'expecting distributions: '
|
||||||
@@ -1258,7 +1228,7 @@ class Session:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
color('received key distribution on a non-encrypted connection', 'red')
|
color('received key distribution on a non-encrypted connection', 'red')
|
||||||
)
|
)
|
||||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check that this command class is expected
|
# Check that this command class is expected
|
||||||
@@ -1278,7 +1248,7 @@ class Session:
|
|||||||
'red',
|
'red',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
|
||||||
|
|
||||||
async def pair(self) -> None:
|
async def pair(self) -> None:
|
||||||
# Start pairing as an initiator
|
# Start pairing as an initiator
|
||||||
@@ -1389,34 +1359,56 @@ class Session:
|
|||||||
)
|
)
|
||||||
await self.manager.on_pairing(self, peer_address, keys)
|
await self.manager.on_pairing(self, peer_address, keys)
|
||||||
|
|
||||||
def on_pairing_failure(self, reason: int) -> None:
|
def on_pairing_failure(self, reason: ErrorCode) -> None:
|
||||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
logger.warning('pairing failure (%s)', reason.name)
|
||||||
|
|
||||||
if self.completed:
|
if self.completed:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.completed = True
|
self.completed = True
|
||||||
|
|
||||||
error = ProtocolError(reason, 'smp', error_name(reason))
|
error = ProtocolError(reason, 'smp', reason.name)
|
||||||
if self.pairing_result is not None and not self.pairing_result.done():
|
if self.pairing_result is not None and not self.pairing_result.done():
|
||||||
self.pairing_result.set_exception(error)
|
self.pairing_result.set_exception(error)
|
||||||
self.manager.on_pairing_failure(self, reason)
|
self.manager.on_pairing_failure(self, reason)
|
||||||
|
|
||||||
def on_smp_command(self, command: SMP_Command) -> None:
|
def on_smp_command(self, command: SMP_Command) -> None:
|
||||||
# Find the handler method
|
try:
|
||||||
handler_name = f'on_{command.name.lower()}'
|
match command:
|
||||||
handler = getattr(self, handler_name, None)
|
case SMP_Pairing_Request_Command():
|
||||||
if handler is not None:
|
self.on_smp_pairing_request_command(command)
|
||||||
try:
|
case SMP_Pairing_Response_Command():
|
||||||
handler(command)
|
self.on_smp_pairing_response_command(command)
|
||||||
except Exception:
|
case SMP_Pairing_Confirm_Command():
|
||||||
logger.exception(color("!!! Exception in handler:", "red"))
|
self.on_smp_pairing_confirm_command(command)
|
||||||
response = SMP_Pairing_Failed_Command(
|
case SMP_Pairing_Random_Command():
|
||||||
reason=SMP_UNSPECIFIED_REASON_ERROR
|
self.on_smp_pairing_random_command(command)
|
||||||
)
|
case SMP_Pairing_Failed_Command():
|
||||||
self.send_command(response)
|
self.on_smp_pairing_failed_command(command)
|
||||||
else:
|
case SMP_Encryption_Information_Command():
|
||||||
logger.error(color('SMP command not handled???', 'red'))
|
self.on_smp_encryption_information_command(command)
|
||||||
|
case SMP_Master_Identification_Command():
|
||||||
|
self.on_smp_master_identification_command(command)
|
||||||
|
case SMP_Identity_Information_Command():
|
||||||
|
self.on_smp_identity_information_command(command)
|
||||||
|
case SMP_Identity_Address_Information_Command():
|
||||||
|
self.on_smp_identity_address_information_command(command)
|
||||||
|
case SMP_Signing_Information_Command():
|
||||||
|
self.on_smp_signing_information_command(command)
|
||||||
|
case SMP_Pairing_Public_Key_Command():
|
||||||
|
self.on_smp_pairing_public_key_command(command)
|
||||||
|
case SMP_Pairing_DHKey_Check_Command():
|
||||||
|
self.on_smp_pairing_dhkey_check_command(command)
|
||||||
|
# case SMP_Security_Request_Command():
|
||||||
|
# self.on_smp_security_request_command(command)
|
||||||
|
# case SMP_Pairing_Keypress_Notification_Command():
|
||||||
|
# self.on_smp_pairing_keypress_notification_command(command)
|
||||||
|
case _:
|
||||||
|
logger.error(color('SMP command not handled', 'red'))
|
||||||
|
except Exception:
|
||||||
|
logger.exception(color("!!! Exception in handler:", "red"))
|
||||||
|
response = SMP_Pairing_Failed_Command(reason=ErrorCode.UNSPECIFIED_REASON)
|
||||||
|
self.send_command(response)
|
||||||
|
|
||||||
def on_smp_pairing_request_command(
|
def on_smp_pairing_request_command(
|
||||||
self, command: SMP_Pairing_Request_Command
|
self, command: SMP_Pairing_Request_Command
|
||||||
@@ -1436,16 +1428,16 @@ class Session:
|
|||||||
accepted = False
|
accepted = False
|
||||||
if not accepted:
|
if not accepted:
|
||||||
logger.debug('pairing rejected by delegate')
|
logger.debug('pairing rejected by delegate')
|
||||||
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
|
self.send_pairing_failed(ErrorCode.PAIRING_NOT_SUPPORTED)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Save the request
|
# Save the request
|
||||||
self.preq = bytes(command)
|
self.preq = bytes(command)
|
||||||
|
|
||||||
# Bonding and SC require both sides to request/support it
|
# Bonding and SC require both sides to request/support it
|
||||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
|
||||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
|
||||||
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
|
self.ct2 = self.ct2 and (command.auth_req & AuthReq.CT2 != 0)
|
||||||
|
|
||||||
# Infer the pairing method
|
# Infer the pairing method
|
||||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||||
@@ -1456,7 +1448,7 @@ class Session:
|
|||||||
if not self.sc and self.tk is None:
|
if not self.sc and self.tk is None:
|
||||||
# For legacy OOB, TK is required.
|
# For legacy OOB, TK is required.
|
||||||
logger.warning("legacy OOB without TK")
|
logger.warning("legacy OOB without TK")
|
||||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
|
||||||
return
|
return
|
||||||
if command.oob_data_flag == 0:
|
if command.oob_data_flag == 0:
|
||||||
# The peer doesn't have OOB data, use r=0
|
# The peer doesn't have OOB data, use r=0
|
||||||
@@ -1475,8 +1467,11 @@ class Session:
|
|||||||
(
|
(
|
||||||
self.initiator_key_distribution,
|
self.initiator_key_distribution,
|
||||||
self.responder_key_distribution,
|
self.responder_key_distribution,
|
||||||
) = await self.pairing_config.delegate.key_distribution_response(
|
) = map(
|
||||||
command.initiator_key_distribution, command.responder_key_distribution
|
KeyDistribution,
|
||||||
|
await self.pairing_config.delegate.key_distribution_response(
|
||||||
|
command.initiator_key_distribution, command.responder_key_distribution
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.compute_peer_expected_distributions(self.initiator_key_distribution)
|
self.compute_peer_expected_distributions(self.initiator_key_distribution)
|
||||||
|
|
||||||
@@ -1514,8 +1509,8 @@ class Session:
|
|||||||
self.peer_io_capability = command.io_capability
|
self.peer_io_capability = command.io_capability
|
||||||
|
|
||||||
# Bonding and SC require both sides to request/support it
|
# Bonding and SC require both sides to request/support it
|
||||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
|
||||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
|
||||||
|
|
||||||
# Infer the pairing method
|
# Infer the pairing method
|
||||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||||
@@ -1526,7 +1521,7 @@ class Session:
|
|||||||
if not self.sc and self.tk is None:
|
if not self.sc and self.tk is None:
|
||||||
# For legacy OOB, TK is required.
|
# For legacy OOB, TK is required.
|
||||||
logger.warning("legacy OOB without TK")
|
logger.warning("legacy OOB without TK")
|
||||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
|
||||||
return
|
return
|
||||||
if command.oob_data_flag == 0:
|
if command.oob_data_flag == 0:
|
||||||
# The peer doesn't have OOB data, use r=0
|
# The peer doesn't have OOB data, use r=0
|
||||||
@@ -1546,7 +1541,7 @@ class Session:
|
|||||||
command.responder_key_distribution & ~self.responder_key_distribution != 0
|
command.responder_key_distribution & ~self.responder_key_distribution != 0
|
||||||
):
|
):
|
||||||
# The response isn't a subset of the request
|
# The response isn't a subset of the request
|
||||||
self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR)
|
self.send_pairing_failed(ErrorCode.INVALID_PARAMETERS)
|
||||||
return
|
return
|
||||||
self.initiator_key_distribution = command.initiator_key_distribution
|
self.initiator_key_distribution = command.initiator_key_distribution
|
||||||
self.responder_key_distribution = command.responder_key_distribution
|
self.responder_key_distribution = command.responder_key_distribution
|
||||||
@@ -1624,7 +1619,7 @@ class Session:
|
|||||||
)
|
)
|
||||||
assert self.confirm_value
|
assert self.confirm_value
|
||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1665,7 +1660,7 @@ class Session:
|
|||||||
self.pkb, self.pka, command.random_value, bytes([0])
|
self.pkb, self.pka, command.random_value, bytes([0])
|
||||||
)
|
)
|
||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||||
@@ -1678,7 +1673,7 @@ class Session:
|
|||||||
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
||||||
)
|
)
|
||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1707,7 +1702,7 @@ class Session:
|
|||||||
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
||||||
)
|
)
|
||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1824,7 +1819,7 @@ class Session:
|
|||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
self.peer_oob_data.c,
|
self.peer_oob_data.c,
|
||||||
confirm_verifier,
|
confirm_verifier,
|
||||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
ErrorCode.CONFIRM_VALUE_FAILED,
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1858,7 +1853,7 @@ class Session:
|
|||||||
expected = self.eb if self.is_initiator else self.ea
|
expected = self.eb if self.is_initiator else self.ea
|
||||||
assert expected
|
assert expected
|
||||||
if not self.check_expected_value(
|
if not self.check_expected_value(
|
||||||
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
|
expected, command.dhkey_check, ErrorCode.DHKEY_CHECK_FAILED
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1937,6 +1932,7 @@ class Manager(utils.EventEmitter):
|
|||||||
self._ecc_key = None
|
self._ecc_key = None
|
||||||
self.pairing_config_factory = pairing_config_factory
|
self.pairing_config_factory = pairing_config_factory
|
||||||
self.session_proxy = Session
|
self.session_proxy = Session
|
||||||
|
self.debug_mode = False
|
||||||
|
|
||||||
def send_command(self, connection: Connection, command: SMP_Command) -> None:
|
def send_command(self, connection: Connection, command: SMP_Command) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -1962,7 +1958,7 @@ class Manager(utils.EventEmitter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Security request is more than just pairing, so let applications handle them
|
# Security request is more than just pairing, so let applications handle them
|
||||||
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
if command.code == CommandCode.SECURITY_REQUEST:
|
||||||
self.on_smp_security_request_command(
|
self.on_smp_security_request_command(
|
||||||
connection, cast(SMP_Security_Request_Command, command)
|
connection, cast(SMP_Security_Request_Command, command)
|
||||||
)
|
)
|
||||||
@@ -1983,6 +1979,13 @@ class Manager(utils.EventEmitter):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def ecc_key(self) -> crypto.EccKey:
|
def ecc_key(self) -> crypto.EccKey:
|
||||||
|
if self.debug_mode:
|
||||||
|
# Core - Vol 3, Part H:
|
||||||
|
# When the Security Manager is placed in a Debug mode it shall use the
|
||||||
|
# following Diffie-Hellman private / public key pair:
|
||||||
|
debug_key = crypto.EccKey.from_private_key_bytes(SMP_DEBUG_KEY_PRIVATE)
|
||||||
|
return debug_key
|
||||||
|
|
||||||
if self._ecc_key is None:
|
if self._ecc_key is None:
|
||||||
self._ecc_key = crypto.EccKey.generate()
|
self._ecc_key = crypto.EccKey.generate()
|
||||||
assert self._ecc_key
|
assert self._ecc_key
|
||||||
@@ -2002,15 +2005,13 @@ class Manager(utils.EventEmitter):
|
|||||||
def request_pairing(self, connection: Connection) -> None:
|
def request_pairing(self, connection: Connection) -> None:
|
||||||
pairing_config = self.pairing_config_factory(connection)
|
pairing_config = self.pairing_config_factory(connection)
|
||||||
if pairing_config:
|
if pairing_config:
|
||||||
auth_req = smp_auth_req(
|
auth_req = AuthReq.from_booleans(
|
||||||
pairing_config.bonding,
|
bonding=pairing_config.bonding,
|
||||||
pairing_config.mitm,
|
sc=pairing_config.sc,
|
||||||
pairing_config.sc,
|
mitm=pairing_config.mitm,
|
||||||
False,
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
auth_req = 0
|
auth_req = AuthReq(0)
|
||||||
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
|
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
|
||||||
|
|
||||||
def on_session_start(self, session: Session) -> None:
|
def on_session_start(self, session: Session) -> None:
|
||||||
@@ -2026,7 +2027,7 @@ class Manager(utils.EventEmitter):
|
|||||||
# Notify the device
|
# Notify the device
|
||||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||||
|
|
||||||
def on_pairing_failure(self, session: Session, reason: int) -> None:
|
def on_pairing_failure(self, session: Session, reason: ErrorCode) -> None:
|
||||||
self.device.on_pairing_failure(session.connection, reason)
|
self.device.on_pairing_failure(session.connection, reason)
|
||||||
|
|
||||||
def on_session_end(self, session: Session) -> None:
|
def on_session_end(self, session: Session) -> None:
|
||||||
|
|||||||
@@ -133,10 +133,10 @@ def on_avrcp_start(
|
|||||||
utils.AsyncRunner.spawn(get_supported_events())
|
utils.AsyncRunner.spawn(get_supported_events())
|
||||||
|
|
||||||
async def monitor_track_changed() -> None:
|
async def monitor_track_changed() -> None:
|
||||||
async for identifier in avrcp_protocol.monitor_track_changed():
|
async for uid in avrcp_protocol.monitor_track_changed():
|
||||||
print("TRACK CHANGED:", identifier.hex())
|
print("TRACK CHANGED:", hex(uid))
|
||||||
websocket_server.send_message(
|
websocket_server.send_message(
|
||||||
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
|
{"type": "track-changed", "params": {"identifier": hex(uid)}}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def monitor_playback_status() -> None:
|
async def monitor_playback_status() -> None:
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ async def main() -> None:
|
|||||||
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
|
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
|
||||||
)
|
)
|
||||||
server_device.add_service(device_info_service)
|
server_device.add_service(device_info_service)
|
||||||
|
await server_device.start_advertising()
|
||||||
|
|
||||||
# Connect the client to the server
|
# Connect the client to the server
|
||||||
connection = await client_device.connect(server_device.random_address)
|
connection = await client_device.connect(server_device.random_address)
|
||||||
|
|||||||
+2
-3
@@ -13,13 +13,12 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }]
|
|||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp ~= 3.8; platform_system!='Emscripten'",
|
"aiohttp ~= 3.8; platform_system!='Emscripten'",
|
||||||
"appdirs >= 1.4; platform_system!='Emscripten'",
|
|
||||||
"click >= 8.1.3; platform_system!='Emscripten'",
|
"click >= 8.1.3; platform_system!='Emscripten'",
|
||||||
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
|
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
|
||||||
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
|
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
|
||||||
# versions available on PyPI. Relax the version requirement since it's better than being
|
# versions available on PyPI. Relax the version requirement since it's better than being
|
||||||
# completely unable to import the package in case of version mismatch.
|
# completely unable to import the package in case of version mismatch.
|
||||||
"cryptography >= 44.0.3; platform_system=='Emscripten'",
|
"cryptography >= 39.0.0; platform_system=='Emscripten'",
|
||||||
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
|
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
|
||||||
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
|
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
|
||||||
# updated. Relax the version requirement since it's better than being completely unable
|
# updated. Relax the version requirement since it's better than being completely unable
|
||||||
@@ -37,7 +36,7 @@ dependencies = [
|
|||||||
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
|
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
|
||||||
"pyserial >= 3.5; platform_system!='Emscripten'",
|
"pyserial >= 3.5; platform_system!='Emscripten'",
|
||||||
"pyusb >= 1.2; platform_system!='Emscripten'",
|
"pyusb >= 1.2; platform_system!='Emscripten'",
|
||||||
"tomli ~= 2.2.1; platform_system!='Emscripten'",
|
"tomli ~= 2.2.1; platform_system!='Emscripten' and python_version<'3.11'",
|
||||||
"websockets >= 15.0.1; platform_system!='Emscripten'",
|
"websockets >= 15.0.1; platform_system!='Emscripten'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Generated
+2
-2
@@ -221,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bytes"
|
name = "bytes"
|
||||||
version = "1.5.0"
|
version = "1.11.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
|
|||||||
+1
-1
@@ -30,7 +30,7 @@ hex = "0.4.3"
|
|||||||
itertools = "0.11.0"
|
itertools = "0.11.0"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
thiserror = "1.0.41"
|
thiserror = "1.0.41"
|
||||||
bytes = "1.5.0"
|
bytes = "1.11.1"
|
||||||
pdl-derive = "0.2.0"
|
pdl-derive = "0.2.0"
|
||||||
pdl-runtime = "0.2.0"
|
pdl-runtime = "0.2.0"
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
|
|||||||
@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
|
|||||||
assert message.payload == parsed.payload
|
assert message.payload == parsed.payload
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'pdu',
|
||||||
|
(
|
||||||
|
b'', # empty PDU — would IndexError on pdu[0]
|
||||||
|
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
|
||||||
|
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
|
||||||
|
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_message_assembler_truncated_pdu(pdu: bytes):
|
||||||
|
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
|
||||||
|
same DoS class as #912 (ATT empty PDU). The assembler is required to
|
||||||
|
log + drop and stay alive so the L2CAP channel survives."""
|
||||||
|
completed = []
|
||||||
|
|
||||||
|
def callback(transaction_label, message):
|
||||||
|
completed.append((transaction_label, message))
|
||||||
|
|
||||||
|
assembler = avdtp.MessageAssembler(callback)
|
||||||
|
# Must not raise; nothing should be delivered to callback either.
|
||||||
|
assembler.on_pdu(pdu)
|
||||||
|
assert not completed
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_rtp():
|
def test_rtp():
|
||||||
packet = bytes.fromhex(
|
packet = bytes.fromhex(
|
||||||
|
|||||||
+179
-3
@@ -20,6 +20,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import struct
|
import struct
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -118,8 +119,6 @@ class TwoDevices(test_utils.TwoDevices):
|
|||||||
scope=avrcp.Scope.NOW_PLAYING,
|
scope=avrcp.Scope.NOW_PLAYING,
|
||||||
uid=0,
|
uid=0,
|
||||||
uid_counter=1,
|
uid_counter=1,
|
||||||
start_item=0,
|
|
||||||
end_item=0,
|
|
||||||
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
|
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
|
||||||
),
|
),
|
||||||
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
|
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
|
||||||
@@ -136,7 +135,7 @@ def test_command(command: avrcp.Command):
|
|||||||
"event,",
|
"event,",
|
||||||
[
|
[
|
||||||
avrcp.UidsChangedEvent(uid_counter=7),
|
avrcp.UidsChangedEvent(uid_counter=7),
|
||||||
avrcp.TrackChangedEvent(identifier=b'12356'),
|
avrcp.TrackChangedEvent(uid=12356),
|
||||||
avrcp.VolumeChangedEvent(volume=9),
|
avrcp.VolumeChangedEvent(volume=9),
|
||||||
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
|
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
|
||||||
avrcp.AddressedPlayerChangedEvent(
|
avrcp.AddressedPlayerChangedEvent(
|
||||||
@@ -581,6 +580,87 @@ async def test_get_supported_company_ids():
|
|||||||
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
|
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_player_application_settings():
|
||||||
|
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
expected_settings = {
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: [
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.SINGLE_TRACK_REPEAT,
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.OFF,
|
||||||
|
],
|
||||||
|
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: [
|
||||||
|
avrcp.ApplicationSetting.ShuffleOnOffStatus.OFF,
|
||||||
|
avrcp.ApplicationSetting.ShuffleOnOffStatus.ALL_TRACKS_SHUFFLE,
|
||||||
|
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||||
|
supported_player_app_settings=expected_settings
|
||||||
|
)
|
||||||
|
actual_settings = await two_devices.protocols[
|
||||||
|
0
|
||||||
|
].list_supported_player_app_settings()
|
||||||
|
assert actual_settings == expected_settings
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_set_player_app_settings():
|
||||||
|
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate
|
||||||
|
await two_devices.protocols[0].send_avrcp_command(
|
||||||
|
avc.CommandFrame.CommandType.CONTROL,
|
||||||
|
avrcp.SetPlayerApplicationSettingValueCommand(
|
||||||
|
attribute=[
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||||
|
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||||
|
],
|
||||||
|
value=[
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||||
|
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
expected_settings = {
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||||
|
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||||
|
}
|
||||||
|
assert delegate.player_app_settings == expected_settings
|
||||||
|
|
||||||
|
actual_settings = await two_devices.protocols[0].get_player_app_settings(
|
||||||
|
[
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||||
|
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert actual_settings == expected_settings
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_play_item():
|
||||||
|
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate
|
||||||
|
|
||||||
|
with mock.patch.object(delegate, delegate.play_item.__name__) as play_item_mock:
|
||||||
|
await two_devices.protocols[0].send_avrcp_command(
|
||||||
|
avc.CommandFrame.CommandType.CONTROL,
|
||||||
|
avrcp.PlayItemCommand(
|
||||||
|
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
play_item_mock.assert_called_once_with(
|
||||||
|
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_monitor_volume():
|
async def test_monitor_volume():
|
||||||
@@ -635,6 +715,102 @@ async def test_monitor_now_playing_content():
|
|||||||
await anext(now_playing_iter)
|
await anext(now_playing_iter)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_track_changed():
|
||||||
|
two_devices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||||
|
[avrcp.EventId.TRACK_CHANGED]
|
||||||
|
)
|
||||||
|
delegate.current_track_uid = avrcp.TrackChangedEvent.NO_TRACK
|
||||||
|
track_iter = two_devices.protocols[0].monitor_track_changed()
|
||||||
|
|
||||||
|
# Interim
|
||||||
|
assert (await anext(track_iter)) == avrcp.TrackChangedEvent.NO_TRACK
|
||||||
|
# Changed
|
||||||
|
two_devices.protocols[1].notify_track_changed(1)
|
||||||
|
assert (await anext(track_iter)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_uid_changed():
|
||||||
|
two_devices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||||
|
[avrcp.EventId.UIDS_CHANGED]
|
||||||
|
)
|
||||||
|
delegate.uid_counter = 0
|
||||||
|
uid_iter = two_devices.protocols[0].monitor_uids()
|
||||||
|
|
||||||
|
# Interim
|
||||||
|
assert (await anext(uid_iter)) == 0
|
||||||
|
# Changed
|
||||||
|
two_devices.protocols[1].notify_uids_changed(1)
|
||||||
|
assert (await anext(uid_iter)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_addressed_player():
|
||||||
|
two_devices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||||
|
[avrcp.EventId.ADDRESSED_PLAYER_CHANGED]
|
||||||
|
)
|
||||||
|
delegate.uid_counter = 0
|
||||||
|
delegate.addressed_player_id = 0
|
||||||
|
addressed_player_iter = two_devices.protocols[0].monitor_addressed_player()
|
||||||
|
|
||||||
|
# Interim
|
||||||
|
assert (
|
||||||
|
await anext(addressed_player_iter)
|
||||||
|
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=0, uid_counter=0)
|
||||||
|
# Changed
|
||||||
|
two_devices.protocols[1].notify_addressed_player_changed(
|
||||||
|
avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
await anext(addressed_player_iter)
|
||||||
|
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_player_app_settings():
|
||||||
|
two_devices = await TwoDevices.create_with_avdtp()
|
||||||
|
|
||||||
|
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||||
|
supported_events=[avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED]
|
||||||
|
)
|
||||||
|
delegate.player_app_settings = {
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
|
||||||
|
}
|
||||||
|
settings_iter = two_devices.protocols[0].monitor_player_application_settings()
|
||||||
|
|
||||||
|
# Interim
|
||||||
|
interim = await anext(settings_iter)
|
||||||
|
assert interim[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
|
||||||
|
assert (
|
||||||
|
interim[0].value_id
|
||||||
|
== avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
|
||||||
|
)
|
||||||
|
|
||||||
|
# Changed
|
||||||
|
two_devices.protocols[1].notify_player_application_settings_changed(
|
||||||
|
[
|
||||||
|
avrcp.PlayerApplicationSettingChangedEvent.Setting(
|
||||||
|
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||||
|
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
changed = await anext(settings_iter)
|
||||||
|
assert changed[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
|
||||||
|
assert changed[0].value_id == avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_frame_parser()
|
test_frame_parser()
|
||||||
|
|||||||
@@ -73,6 +73,14 @@ def test_uuid_to_hex_str() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def test_uuid_hash() -> None:
|
||||||
|
uuid = UUID("1234")
|
||||||
|
uuid_128_bytes = UUID.from_bytes(uuid.to_bytes(force_128=True))
|
||||||
|
assert uuid in {uuid_128_bytes}
|
||||||
|
assert uuid_128_bytes in {uuid}
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_appearance() -> None:
|
def test_appearance() -> None:
|
||||||
a = Appearance(Appearance.Category.COMPUTER, Appearance.ComputerSubcategory.LAPTOP)
|
a = Appearance(Appearance.Category.COMPUTER, Appearance.ComputerSubcategory.LAPTOP)
|
||||||
|
|||||||
+40
-1
@@ -309,6 +309,27 @@ async def test_legacy_advertising_disconnection(auto_restart):
|
|||||||
assert not devices[0].is_advertising
|
assert not devices[0].is_advertising
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_le_multiple_connects():
|
||||||
|
devices = TwoDevices()
|
||||||
|
for controller in devices.controllers:
|
||||||
|
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
|
||||||
|
for dev in devices:
|
||||||
|
await dev.power_on()
|
||||||
|
await devices[0].start_advertising(auto_restart=True, advertising_interval_min=1.0)
|
||||||
|
|
||||||
|
connection = await devices[1].connect(devices[0].random_address)
|
||||||
|
await connection.disconnect()
|
||||||
|
|
||||||
|
await async_barrier()
|
||||||
|
await async_barrier()
|
||||||
|
|
||||||
|
# a second connection attempt is working
|
||||||
|
connection = await devices[1].connect(devices[0].random_address)
|
||||||
|
await connection.disconnect()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_advertising_and_scanning():
|
async def test_advertising_and_scanning():
|
||||||
@@ -445,7 +466,9 @@ async def test_get_remote_le_features():
|
|||||||
devices = TwoDevices()
|
devices = TwoDevices()
|
||||||
await devices.setup_connection()
|
await devices.setup_connection()
|
||||||
|
|
||||||
assert (await devices.connections[0].get_remote_le_features()) is not None
|
assert (
|
||||||
|
await devices.connections[0].get_remote_le_features()
|
||||||
|
) == devices.controllers[1].le_features
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -803,6 +826,22 @@ async def test_remote_name_request():
|
|||||||
assert actual_name == expected_name
|
assert actual_name == expected_name
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_remote_classic_features():
|
||||||
|
devices = TwoDevices()
|
||||||
|
devices[0].classic_enabled = True
|
||||||
|
devices[1].classic_enabled = True
|
||||||
|
await devices[0].power_on()
|
||||||
|
await devices[1].power_on()
|
||||||
|
connection = await devices[0].connect_classic(devices[1].public_address)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await asyncio.wait_for(connection.get_remote_classic_features(), _TIMEOUT)
|
||||||
|
== devices.controllers[1].lmp_features
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def run_test_device():
|
async def run_test_device():
|
||||||
await test_device_connect_parallel()
|
await test_device_connect_parallel()
|
||||||
|
|||||||
+31
-19
@@ -22,6 +22,7 @@ import unittest.mock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from bumble import controller, hci
|
||||||
from bumble.controller import Controller
|
from bumble.controller import Controller
|
||||||
from bumble.hci import (
|
from bumble.hci import (
|
||||||
HCI_AclDataPacket,
|
HCI_AclDataPacket,
|
||||||
@@ -49,34 +50,27 @@ logger = logging.getLogger(__name__)
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'supported_commands, lmp_features',
|
'supported_commands, max_lmp_features_page_number',
|
||||||
[
|
[
|
||||||
(
|
(controller.Controller.supported_commands, 0),
|
||||||
# Default commands
|
|
||||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
|
||||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000',
|
|
||||||
# Only LE LMP feature
|
|
||||||
'0000000060000000',
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
# All commands
|
# All commands
|
||||||
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
|
set(hci.HCI_Command.command_names.keys()),
|
||||||
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
|
|
||||||
# 3 pages of LMP features
|
# 3 pages of LMP features
|
||||||
'000102030405060708090A0B0C0D0E0F011112131415161718191A1B1C1D1E1F',
|
2,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_reset(supported_commands: str, lmp_features: str):
|
async def test_reset(supported_commands: set[int], max_lmp_features_page_number: int):
|
||||||
controller = Controller('C')
|
controller = Controller('C')
|
||||||
controller.supported_commands = bytes.fromhex(supported_commands)
|
controller.supported_commands = supported_commands
|
||||||
controller.lmp_features = bytes.fromhex(lmp_features)
|
controller.lmp_features_max_page_number = max_lmp_features_page_number
|
||||||
host = Host(controller, AsyncPipeSink(controller))
|
host = Host(controller, AsyncPipeSink(controller))
|
||||||
|
|
||||||
await host.reset()
|
await host.reset()
|
||||||
|
|
||||||
assert host.local_lmp_features == int.from_bytes(
|
assert host.local_lmp_features == (
|
||||||
bytes.fromhex(lmp_features), 'little'
|
controller.lmp_features & ~(1 << (64 * max_lmp_features_page_number + 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,14 +171,15 @@ class Source:
|
|||||||
|
|
||||||
|
|
||||||
class Sink:
|
class Sink:
|
||||||
response: HCI_Event
|
response: HCI_Event | None
|
||||||
|
|
||||||
def __init__(self, source: Source, response: HCI_Event) -> None:
|
def __init__(self, source: Source, response: HCI_Event | None) -> None:
|
||||||
self.source = source
|
self.source = source
|
||||||
self.response = response
|
self.response = response
|
||||||
|
|
||||||
def on_packet(self, packet: bytes) -> None:
|
def on_packet(self, packet: bytes) -> None:
|
||||||
self.source.sink.on_packet(bytes(self.response))
|
if self.response is not None:
|
||||||
|
self.source.sink.on_packet(bytes(self.response))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -234,6 +229,23 @@ async def test_send_sync_command() -> None:
|
|||||||
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
|
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_sync_command_timeout() -> None:
|
||||||
|
source = Source()
|
||||||
|
sink = Sink(source, None)
|
||||||
|
|
||||||
|
host = Host(source, sink)
|
||||||
|
host.ready = True
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
|
||||||
|
|
||||||
|
# The sending semaphore should have been released, so this should not block
|
||||||
|
# indefinitely
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_async_command() -> None:
|
async def test_send_async_command() -> None:
|
||||||
source = Source()
|
source = Source()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -179,11 +180,55 @@ async def test_default_namespace(temporary_file):
|
|||||||
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
|
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_filename(tmp_path):
|
||||||
|
import platformdirs
|
||||||
|
|
||||||
|
with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path):
|
||||||
|
# Case 1: no namespace, no filename
|
||||||
|
keystore = JsonKeyStore(None, None)
|
||||||
|
expected_directory = tmp_path / 'Pairing'
|
||||||
|
expected_filename = expected_directory / 'keys.json'
|
||||||
|
assert keystore.directory_name == expected_directory
|
||||||
|
assert keystore.filename == expected_filename
|
||||||
|
|
||||||
|
# Save some data
|
||||||
|
keys = PairingKeys()
|
||||||
|
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||||
|
keys.ltk = PairingKeys.Key(ltk)
|
||||||
|
await keystore.update('foo', keys)
|
||||||
|
assert expected_filename.exists()
|
||||||
|
|
||||||
|
# Load back
|
||||||
|
keystore2 = JsonKeyStore(None, None)
|
||||||
|
foo = await keystore2.get('foo')
|
||||||
|
assert foo is not None
|
||||||
|
assert foo.ltk.value == ltk
|
||||||
|
|
||||||
|
# Case 2: namespace, no filename
|
||||||
|
keystore3 = JsonKeyStore('my:namespace', None)
|
||||||
|
# safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-')
|
||||||
|
expected_filename3 = expected_directory / 'my-namespace.json'
|
||||||
|
assert keystore3.filename == expected_filename3
|
||||||
|
|
||||||
|
# Save some data
|
||||||
|
await keystore3.update('bar', keys)
|
||||||
|
assert expected_filename3.exists()
|
||||||
|
|
||||||
|
# Load back
|
||||||
|
keystore4 = JsonKeyStore('my:namespace', None)
|
||||||
|
bar = await keystore4.get('bar')
|
||||||
|
assert bar is not None
|
||||||
|
assert bar.ltk.value == ltk
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def run_tests():
|
async def run_tests():
|
||||||
await test_basic()
|
await test_basic()
|
||||||
await test_parsing()
|
await test_parsing()
|
||||||
await test_default_namespace()
|
await test_default_namespace()
|
||||||
|
await test_no_filename()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -18,9 +18,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from bumble import sdp
|
||||||
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
|
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
|
||||||
from bumble.sdp import (
|
from bumble.sdp import (
|
||||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
@@ -206,6 +208,16 @@ def sdp_records(record_count=1):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def test_pdu_parameter_length(caplog) -> None:
|
||||||
|
caplog.set_level(logging.WARNING)
|
||||||
|
pdu = sdp.SDP_ErrorResponse(
|
||||||
|
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
|
||||||
|
)
|
||||||
|
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
|
||||||
|
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_service_search():
|
async def test_service_search():
|
||||||
@@ -428,3 +440,43 @@ async def run():
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
asyncio.run(run())
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def test_nested_sequence_recursion_guard():
|
||||||
|
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
|
||||||
|
the parser with RecursionError. Instead a ValueError is raised once the
|
||||||
|
configured nesting limit is exceeded.
|
||||||
|
|
||||||
|
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
|
||||||
|
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
|
||||||
|
without a depth limit. A malicious SDP peer could craft a PDU exceeding
|
||||||
|
Pythons default recursion limit (~1000 frames) and crash the host.
|
||||||
|
"""
|
||||||
|
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
|
||||||
|
inner = b"\x35\x00" # empty SEQUENCE terminator
|
||||||
|
for _ in range(1500):
|
||||||
|
size = len(inner)
|
||||||
|
if size >= 65535:
|
||||||
|
break
|
||||||
|
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="nesting exceeds max depth"):
|
||||||
|
DataElement.from_bytes(inner)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_sequence_within_limit_still_works():
|
||||||
|
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
|
||||||
|
leaf = DataElement.unsigned_integer(1, value_size=2)
|
||||||
|
payload = leaf
|
||||||
|
for _ in range(16): # under the 32-depth limit
|
||||||
|
payload = DataElement.sequence([payload])
|
||||||
|
raw = bytes(payload)
|
||||||
|
parsed = DataElement.from_bytes(raw)
|
||||||
|
# Walk back down to confirm structural integrity preserved
|
||||||
|
cur = parsed
|
||||||
|
for _ in range(16):
|
||||||
|
assert cur.type == DataElement.SEQUENCE
|
||||||
|
cur = cur.value[0]
|
||||||
|
assert cur.type == DataElement.UNSIGNED_INTEGER
|
||||||
|
assert cur.value == 1
|
||||||
|
|||||||
+5
-6
@@ -29,8 +29,7 @@ from bumble.gatt import Characteristic, Service
|
|||||||
from bumble.hci import Role
|
from bumble.hci import Role
|
||||||
from bumble.pairing import PairingConfig, PairingDelegate
|
from bumble.pairing import PairingConfig, PairingDelegate
|
||||||
from bumble.smp import (
|
from bumble.smp import (
|
||||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
ErrorCode,
|
||||||
SMP_PAIRING_NOT_SUPPORTED_ERROR,
|
|
||||||
OobContext,
|
OobContext,
|
||||||
OobLegacyContext,
|
OobLegacyContext,
|
||||||
)
|
)
|
||||||
@@ -378,7 +377,7 @@ async def test_self_smp_reject():
|
|||||||
await _test_self_smp_with_configs(None, rejecting_pairing_config)
|
await _test_self_smp_with_configs(None, rejecting_pairing_config)
|
||||||
paired = True
|
paired = True
|
||||||
except ProtocolError as error:
|
except ProtocolError as error:
|
||||||
assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
|
assert error.error_code == ErrorCode.PAIRING_NOT_SUPPORTED
|
||||||
|
|
||||||
assert not paired
|
assert not paired
|
||||||
|
|
||||||
@@ -403,7 +402,7 @@ async def test_self_smp_wrong_pin():
|
|||||||
)
|
)
|
||||||
paired = True
|
paired = True
|
||||||
except ProtocolError as error:
|
except ProtocolError as error:
|
||||||
assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
assert error.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
|
|
||||||
assert not paired
|
assert not paired
|
||||||
|
|
||||||
@@ -534,11 +533,11 @@ async def test_self_smp_oob_sc():
|
|||||||
|
|
||||||
with pytest.raises(ProtocolError) as error:
|
with pytest.raises(ProtocolError) as error:
|
||||||
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
|
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
|
||||||
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
|
|
||||||
with pytest.raises(ProtocolError):
|
with pytest.raises(ProtocolError):
|
||||||
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
|
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
|
||||||
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
+15
-1
@@ -24,7 +24,7 @@ import pytest
|
|||||||
from bumble import crypto, pairing, smp
|
from bumble import crypto, pairing, smp
|
||||||
from bumble.core import AdvertisingData
|
from bumble.core import AdvertisingData
|
||||||
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
||||||
from bumble.device import Device
|
from bumble.device import Device, DeviceConfiguration
|
||||||
from bumble.hci import Address
|
from bumble.hci import Address
|
||||||
from bumble.pairing import LeRole, OobData, OobSharedData
|
from bumble.pairing import LeRole, OobData, OobSharedData
|
||||||
|
|
||||||
@@ -312,3 +312,17 @@ async def test_send_identity_address_command(
|
|||||||
actual_command = mock_method.call_args.args[0]
|
actual_command = mock_method.call_args.args[0]
|
||||||
assert actual_command.addr_type == expected_identity_address.address_type
|
assert actual_command.addr_type == expected_identity_address.address_type
|
||||||
assert actual_command.bd_addr == expected_identity_address
|
assert actual_command.bd_addr == expected_identity_address
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smp_debug_mode():
|
||||||
|
config = DeviceConfiguration(smp_debug_mode=True)
|
||||||
|
device = Device(config=config)
|
||||||
|
|
||||||
|
assert device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
|
||||||
|
assert device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
|
||||||
|
|
||||||
|
device.smp_manager.debug_mode = False
|
||||||
|
|
||||||
|
assert not device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
|
||||||
|
assert not device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
||||||
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" />
|
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" />
|
||||||
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
|
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
|
||||||
<script type="module" src="../ui.js"></script>
|
<script type="module" src="../ui.js"></script>
|
||||||
<script type="module" src="heart_rate_monitor.js"></script>
|
<script type="module" src="heart_rate_monitor.js"></script>
|
||||||
<style>
|
<style>
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
||||||
<link rel="stylesheet" href="scanner.css">
|
<link rel="stylesheet" href="scanner.css">
|
||||||
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
|
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
|
||||||
<script type="module" src="../ui.js"></script>
|
<script type="module" src="../ui.js"></script>
|
||||||
<script type="module" src="scanner.js"></script>
|
<script type="module" src="scanner.js"></script>
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
<title>Bumble Speaker</title>
|
<title>Bumble Speaker</title>
|
||||||
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
||||||
<link rel="stylesheet" href="speaker.css">
|
<link rel="stylesheet" href="speaker.css">
|
||||||
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
|
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
|
||||||
<script type="module" src="speaker.js"></script>
|
<script type="module" src="speaker.js"></script>
|
||||||
<script type="module" src="../ui.js"></script>
|
<script type="module" src="../ui.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|||||||
Reference in New Issue
Block a user