mirror of
https://github.com/google/bumble.git
synced 2026-05-06 03:38:01 +00:00
Compare commits
45 Commits
gbg/hci-fi
...
v0.0.226
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eb64debb62 | ||
|
|
c158f25b1e | ||
|
|
1330e83517 | ||
|
|
d9c9bea6cb | ||
|
|
3b937631b3 | ||
|
|
f8aa309111 | ||
|
|
673281ed71 | ||
|
|
3ac7af4683 | ||
|
|
5ebfaae74e | ||
|
|
e6175f85fe | ||
|
|
f9ba527508 | ||
|
|
a407c4cabf | ||
|
|
6c2d6dddb5 | ||
|
|
797cd216d4 | ||
|
|
e2e8c90e47 | ||
|
|
3d5648cdc3 | ||
|
|
d810d93aaf | ||
|
|
81d9adb983 | ||
|
|
377fa896f7 | ||
|
|
79e5974946 | ||
|
|
657451474e | ||
|
|
9f730dce6f | ||
|
|
1a6be95a7e | ||
|
|
aea5320d71 | ||
|
|
91cb1b1df3 | ||
|
|
81bdc86e52 | ||
|
|
f23cad34e3 | ||
|
|
30fde2c00b | ||
|
|
256a1a7405 | ||
|
|
116d9b26bb | ||
|
|
aabe2ca063 | ||
|
|
2d17a5f742 | ||
|
|
3894b14467 | ||
|
|
e62f947430 | ||
|
|
dcb8a4b607 | ||
|
|
7118328b07 | ||
|
|
5dc01d792a | ||
|
|
255f357975 | ||
|
|
c86920558b | ||
|
|
8e6efd0b2f | ||
|
|
34f5b81c7d | ||
|
|
d34d6a5c98 | ||
|
|
aedc971653 | ||
|
|
c6815fb820 | ||
|
|
85b78b46f8 |
@@ -24,13 +24,18 @@ import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import secrets
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
import click
|
||||
import tomli
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
import tomli as tomllib
|
||||
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
@@ -114,7 +119,7 @@ def parse_broadcast_list(filename: str) -> Sequence[Broadcast]:
|
||||
broadcasts: list[Broadcast] = []
|
||||
|
||||
with open(filename, "rb") as config_file:
|
||||
config = tomli.load(config_file)
|
||||
config = tomllib.load(config_file)
|
||||
for broadcast in config.get("broadcasts", []):
|
||||
sources = []
|
||||
for source in broadcast.get("sources", []):
|
||||
|
||||
15
apps/pair.py
15
apps/pair.py
@@ -20,11 +20,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
import click
|
||||
from prompt_toolkit.shortcuts import PromptSession
|
||||
|
||||
from bumble import data_types
|
||||
from bumble import data_types, smp
|
||||
from bumble.a2dp import make_audio_sink_service_sdp_records
|
||||
from bumble.att import (
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
||||
@@ -40,7 +41,7 @@ from bumble.core import (
|
||||
PhysicalTransport,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.device import Connection, Device, Peer
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
@@ -53,7 +54,6 @@ from bumble.hci import OwnAddressType
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.pairing import OobData, PairingConfig, PairingDelegate
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
@@ -65,7 +65,7 @@ POST_PAIRING_DELAY = 1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Waiter:
|
||||
instance: Waiter | None = None
|
||||
instance: ClassVar[Waiter | None] = None
|
||||
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
@@ -319,12 +319,13 @@ async def on_classic_pairing(connection):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_pairing_failure(connection, reason):
|
||||
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode):
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
|
||||
print(color(f'*** Pairing failed: {reason.name}', 'red'))
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
await connection.disconnect()
|
||||
Waiter.instance.terminate()
|
||||
if Waiter.instance:
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
12
apps/scan.py
12
apps/scan.py
@@ -22,7 +22,7 @@ import click
|
||||
import bumble.logging
|
||||
from bumble import data_types
|
||||
from bumble.colors import color
|
||||
from bumble.device import Advertisement, Device
|
||||
from bumble.device import Advertisement, Device, DeviceConfiguration
|
||||
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.smp import AddressResolver
|
||||
@@ -144,8 +144,14 @@ async def scan(
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
device = Device.with_hci(
|
||||
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||
device = Device.from_config_with_hci(
|
||||
DeviceConfiguration(
|
||||
name='Bumble',
|
||||
address=Address('F0:F1:F2:F3:F4:F5'),
|
||||
keystore='JsonKeyStore',
|
||||
),
|
||||
hci_source,
|
||||
hci_sink,
|
||||
)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
@@ -88,13 +88,6 @@ SBC_DUAL_CHANNEL_MODE = 0x01
|
||||
SBC_STEREO_CHANNEL_MODE = 0x02
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
|
||||
|
||||
SBC_CHANNEL_MODE_NAMES = {
|
||||
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
|
||||
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
|
||||
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
|
||||
}
|
||||
|
||||
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
|
||||
|
||||
SBC_SUBBANDS = [4, 8]
|
||||
@@ -102,11 +95,6 @@ SBC_SUBBANDS = [4, 8]
|
||||
SBC_SNR_ALLOCATION_METHOD = 0x00
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
|
||||
|
||||
SBC_ALLOCATION_METHOD_NAMES = {
|
||||
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
||||
}
|
||||
|
||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||
@@ -129,13 +117,6 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
|
||||
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
|
||||
|
||||
MPEG_2_4_OBJECT_TYPE_NAMES = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
|
||||
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
||||
}
|
||||
|
||||
|
||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
@@ -267,26 +248,27 @@ class MediaCodecInformation:
|
||||
def create(
|
||||
cls, media_codec_type: int, data: bytes
|
||||
) -> MediaCodecInformation | bytes:
|
||||
if media_codec_type == CodecType.SBC:
|
||||
return SbcMediaCodecInformation.from_bytes(data)
|
||||
elif media_codec_type == CodecType.MPEG_2_4_AAC:
|
||||
return AacMediaCodecInformation.from_bytes(data)
|
||||
elif media_codec_type == CodecType.NON_A2DP:
|
||||
vendor_media_codec_information = (
|
||||
VendorSpecificMediaCodecInformation.from_bytes(data)
|
||||
)
|
||||
if (
|
||||
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
||||
vendor_media_codec_information.vendor_id
|
||||
)
|
||||
) and (
|
||||
media_codec_information_class := vendor_class_map.get(
|
||||
vendor_media_codec_information.codec_id
|
||||
)
|
||||
):
|
||||
return media_codec_information_class.from_bytes(
|
||||
vendor_media_codec_information.value
|
||||
match media_codec_type:
|
||||
case CodecType.SBC:
|
||||
return SbcMediaCodecInformation.from_bytes(data)
|
||||
case CodecType.MPEG_2_4_AAC:
|
||||
return AacMediaCodecInformation.from_bytes(data)
|
||||
case CodecType.NON_A2DP:
|
||||
vendor_media_codec_information = (
|
||||
VendorSpecificMediaCodecInformation.from_bytes(data)
|
||||
)
|
||||
if (
|
||||
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
||||
vendor_media_codec_information.vendor_id
|
||||
)
|
||||
) and (
|
||||
media_codec_information_class := vendor_class_map.get(
|
||||
vendor_media_codec_information.codec_id
|
||||
)
|
||||
):
|
||||
return media_codec_information_class.from_bytes(
|
||||
vendor_media_codec_information.value
|
||||
)
|
||||
return vendor_media_codec_information
|
||||
|
||||
@classmethod
|
||||
|
||||
62
bumble/at.py
62
bumble/at.py
@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
|
||||
are ignored [..], unless they are embedded in numeric or string constants"
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = []
|
||||
tokens: list[bytearray] = []
|
||||
in_quotes = False
|
||||
token = bytearray()
|
||||
for b in buffer:
|
||||
@@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
|
||||
tokens.append(token[1:-1])
|
||||
token = bytearray()
|
||||
else:
|
||||
if char == b' ':
|
||||
pass
|
||||
elif char == b',' or char == b')':
|
||||
tokens.append(token)
|
||||
tokens.append(char)
|
||||
token = bytearray()
|
||||
elif char == b'(':
|
||||
if len(token) > 0:
|
||||
raise AtParsingError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
elif char == b'"':
|
||||
if len(token) > 0:
|
||||
raise AtParsingError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
else:
|
||||
token.extend(char)
|
||||
match char:
|
||||
case b' ':
|
||||
pass
|
||||
case b',' | b')':
|
||||
tokens.append(token)
|
||||
tokens.append(char)
|
||||
token = bytearray()
|
||||
case b'(':
|
||||
if len(token) > 0:
|
||||
raise AtParsingError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
case b'"':
|
||||
if len(token) > 0:
|
||||
raise AtParsingError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
case _:
|
||||
token.extend(char)
|
||||
|
||||
tokens.append(token)
|
||||
return [bytes(token) for token in tokens if len(token) > 0]
|
||||
@@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
|
||||
current: bytes | list = b''
|
||||
|
||||
for token in tokens:
|
||||
if token == b',':
|
||||
accumulator[-1].append(current)
|
||||
current = b''
|
||||
elif token == b'(':
|
||||
accumulator.append([])
|
||||
elif token == b')':
|
||||
if len(accumulator) < 2:
|
||||
raise AtParsingError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
else:
|
||||
current = token
|
||||
match token:
|
||||
case b',':
|
||||
accumulator[-1].append(current)
|
||||
current = b''
|
||||
case b'(':
|
||||
accumulator.append([])
|
||||
case b')':
|
||||
if len(accumulator) < 2:
|
||||
raise AtParsingError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
case _:
|
||||
current = token
|
||||
|
||||
accumulator[-1].append(current)
|
||||
if len(accumulator) > 1:
|
||||
|
||||
103
bumble/att.py
103
bumble/att.py
@@ -954,12 +954,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
self.permissions = permissions
|
||||
|
||||
# Convert the type to a UUID object if it isn't already
|
||||
if isinstance(attribute_type, str):
|
||||
self.type = UUID(attribute_type)
|
||||
elif isinstance(attribute_type, bytes):
|
||||
self.type = UUID.from_bytes(attribute_type)
|
||||
else:
|
||||
self.type = attribute_type
|
||||
match attribute_type:
|
||||
case str():
|
||||
self.type = UUID(attribute_type)
|
||||
case bytes():
|
||||
self.type = UUID.from_bytes(attribute_type)
|
||||
case _:
|
||||
self.type = attribute_type
|
||||
|
||||
self.value = value
|
||||
|
||||
@@ -994,30 +995,31 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
)
|
||||
|
||||
value: _T | None
|
||||
if isinstance(self.value, AttributeValue):
|
||||
try:
|
||||
read_value = self.value.read(connection)
|
||||
if inspect.isawaitable(read_value):
|
||||
value = await read_value
|
||||
else:
|
||||
value = read_value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
elif isinstance(self.value, AttributeValueV2):
|
||||
try:
|
||||
read_value = self.value.read(bearer)
|
||||
if inspect.isawaitable(read_value):
|
||||
value = await read_value
|
||||
else:
|
||||
value = read_value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
else:
|
||||
value = self.value
|
||||
match self.value:
|
||||
case AttributeValue():
|
||||
try:
|
||||
read_value = self.value.read(connection)
|
||||
if inspect.isawaitable(read_value):
|
||||
value = await read_value
|
||||
else:
|
||||
value = read_value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
case AttributeValueV2():
|
||||
try:
|
||||
read_value = self.value.read(bearer)
|
||||
if inspect.isawaitable(read_value):
|
||||
value = await read_value
|
||||
else:
|
||||
value = read_value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
case _:
|
||||
value = self.value
|
||||
|
||||
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
|
||||
|
||||
@@ -1049,26 +1051,27 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
|
||||
decoded_value = self.decode_value(value)
|
||||
|
||||
if isinstance(self.value, AttributeValue):
|
||||
try:
|
||||
result = self.value.write(connection, decoded_value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
elif isinstance(self.value, AttributeValueV2):
|
||||
try:
|
||||
result = self.value.write(bearer, decoded_value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
else:
|
||||
self.value = decoded_value
|
||||
match self.value:
|
||||
case AttributeValue():
|
||||
try:
|
||||
result = self.value.write(connection, decoded_value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
case AttributeValueV2():
|
||||
try:
|
||||
result = self.value.write(bearer, decoded_value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
case _:
|
||||
self.value = decoded_value
|
||||
|
||||
self.emit(self.EVENT_WRITE, connection, decoded_value)
|
||||
|
||||
|
||||
911
bumble/avrcp.py
911
bumble/avrcp.py
File diff suppressed because it is too large
Load Diff
1079
bumble/controller.py
1079
bumble/controller.py
File diff suppressed because it is too large
Load Diff
138
bumble/core.py
138
bumble/core.py
@@ -280,14 +280,15 @@ class UUID:
|
||||
if not force_128:
|
||||
return self.uuid_bytes
|
||||
|
||||
if len(self.uuid_bytes) == 2:
|
||||
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
|
||||
elif len(self.uuid_bytes) == 4:
|
||||
return self.BASE_UUID + self.uuid_bytes
|
||||
elif len(self.uuid_bytes) == 16:
|
||||
return self.uuid_bytes
|
||||
else:
|
||||
assert False, "unreachable"
|
||||
match len(self.uuid_bytes):
|
||||
case 2:
|
||||
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
|
||||
case 4:
|
||||
return self.BASE_UUID + self.uuid_bytes
|
||||
case 16:
|
||||
return self.uuid_bytes
|
||||
case _:
|
||||
assert False, "unreachable"
|
||||
|
||||
def to_pdu_bytes(self) -> bytes:
|
||||
'''
|
||||
@@ -1769,66 +1770,71 @@ class AdvertisingData:
|
||||
|
||||
@classmethod
|
||||
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
|
||||
if ad_type == AdvertisingData.FLAGS:
|
||||
ad_type_str = 'Flags'
|
||||
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
|
||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:2])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
|
||||
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:4])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
|
||||
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:16])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
|
||||
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME:
|
||||
ad_type_str = 'Shortened Local Name'
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
|
||||
ad_type_str = 'Complete Local Name'
|
||||
try:
|
||||
match ad_type:
|
||||
case AdvertisingData.FLAGS:
|
||||
ad_type_str = 'Flags'
|
||||
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
|
||||
case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||
case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
|
||||
case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||
case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
|
||||
case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||
case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
|
||||
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
|
||||
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
|
||||
case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:2])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
|
||||
case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:4])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
|
||||
case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
|
||||
ad_type_str = 'Service Data'
|
||||
uuid = UUID.from_bytes(ad_data[:16])
|
||||
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
|
||||
case AdvertisingData.SHORTENED_LOCAL_NAME:
|
||||
ad_type_str = 'Shortened Local Name'
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
except UnicodeDecodeError:
|
||||
case AdvertisingData.COMPLETE_LOCAL_NAME:
|
||||
ad_type_str = 'Complete Local Name'
|
||||
try:
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
except UnicodeDecodeError:
|
||||
ad_data_str = ad_data.hex()
|
||||
case AdvertisingData.TX_POWER_LEVEL:
|
||||
ad_type_str = 'TX Power Level'
|
||||
ad_data_str = str(ad_data[0])
|
||||
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
||||
ad_type_str = 'Manufacturer Specific Data'
|
||||
company_id = struct.unpack_from('<H', ad_data, 0)[0]
|
||||
company_name = COMPANY_IDENTIFIERS.get(
|
||||
company_id, f'0x{company_id:04X}'
|
||||
)
|
||||
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
|
||||
case AdvertisingData.APPEARANCE:
|
||||
ad_type_str = 'Appearance'
|
||||
appearance = Appearance.from_int(
|
||||
struct.unpack_from('<H', ad_data, 0)[0]
|
||||
)
|
||||
ad_data_str = str(appearance)
|
||||
case AdvertisingData.BROADCAST_NAME:
|
||||
ad_type_str = 'Broadcast Name'
|
||||
ad_data_str = ad_data.decode('utf-8')
|
||||
case _:
|
||||
ad_type_str = AdvertisingData.Type(ad_type).name
|
||||
ad_data_str = ad_data.hex()
|
||||
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
|
||||
ad_type_str = 'TX Power Level'
|
||||
ad_data_str = str(ad_data[0])
|
||||
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
||||
ad_type_str = 'Manufacturer Specific Data'
|
||||
company_id = struct.unpack_from('<H', ad_data, 0)[0]
|
||||
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
|
||||
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
|
||||
elif ad_type == AdvertisingData.APPEARANCE:
|
||||
ad_type_str = 'Appearance'
|
||||
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
|
||||
ad_data_str = str(appearance)
|
||||
elif ad_type == AdvertisingData.BROADCAST_NAME:
|
||||
ad_type_str = 'Broadcast Name'
|
||||
ad_data_str = ad_data.decode('utf-8')
|
||||
else:
|
||||
ad_type_str = AdvertisingData.Type(ad_type).name
|
||||
ad_data_str = ad_data.hex()
|
||||
|
||||
return f'[{ad_type_str}]: {ad_data_str}'
|
||||
|
||||
|
||||
574
bumble/device.py
574
bumble/device.py
@@ -3748,6 +3748,292 @@ class Device(utils.CompositeEventEmitter):
|
||||
page_scan_enabled=self.connectable,
|
||||
)
|
||||
|
||||
async def connect_le(
|
||||
self,
|
||||
peer_address: hci.Address | str,
|
||||
connection_parameters_preferences: (
|
||||
dict[hci.Phy, ConnectionParametersPreferences] | None
|
||||
) = None,
|
||||
own_address_type: hci.OwnAddressType = hci.OwnAddressType.RANDOM,
|
||||
timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
) -> Connection:
|
||||
# Check that there isn't already a pending connection
|
||||
if self.is_le_connecting:
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
try_resolve = not self.address_resolution_offload
|
||||
if isinstance(peer_address, str):
|
||||
try:
|
||||
peer_address = hci.Address.from_string_for_transport(
|
||||
peer_address, PhysicalTransport.LE
|
||||
)
|
||||
except (InvalidArgumentError, ValueError):
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, PhysicalTransport.LE
|
||||
) # TODO: timeout
|
||||
try_resolve = False
|
||||
|
||||
assert isinstance(peer_address, hci.Address)
|
||||
|
||||
if (
|
||||
try_resolve
|
||||
and self.address_resolver is not None
|
||||
and self.address_resolver.can_resolve_to(peer_address)
|
||||
):
|
||||
# If we have an IRK for this address, we should resolve.
|
||||
logger.debug('have IRK for address, resolving...')
|
||||
peer_address = await self.find_peer_by_identity_address(
|
||||
peer_address
|
||||
) # TODO: timeout
|
||||
|
||||
def on_connection(connection):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error: core.ConnectionError):
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
self.on(self.EVENT_CONNECTION, on_connection)
|
||||
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
|
||||
try:
|
||||
# Tell the controller to connect
|
||||
if connection_parameters_preferences is None:
|
||||
connection_parameters_preferences = {
|
||||
hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default
|
||||
}
|
||||
|
||||
self.connect_own_address_type = own_address_type
|
||||
|
||||
if self.host.supports_command(
|
||||
hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND
|
||||
):
|
||||
# Only keep supported PHYs
|
||||
phys = sorted(
|
||||
list(
|
||||
set(
|
||||
filter(
|
||||
self.supports_le_phy,
|
||||
connection_parameters_preferences.keys(),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
if not phys:
|
||||
raise InvalidArgumentError('at least one supported PHY needed')
|
||||
|
||||
phy_count = len(phys)
|
||||
initiating_phys = hci.phy_list_to_bits(phys)
|
||||
|
||||
connection_interval_mins = [
|
||||
int(
|
||||
connection_parameters_preferences[phy].connection_interval_min
|
||||
/ 1.25
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
connection_interval_maxs = [
|
||||
int(
|
||||
connection_parameters_preferences[phy].connection_interval_max
|
||||
/ 1.25
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
max_latencies = [
|
||||
connection_parameters_preferences[phy].max_latency for phy in phys
|
||||
]
|
||||
supervision_timeouts = [
|
||||
int(connection_parameters_preferences[phy].supervision_timeout / 10)
|
||||
for phy in phys
|
||||
]
|
||||
min_ce_lengths = [
|
||||
int(connection_parameters_preferences[phy].min_ce_length / 0.625)
|
||||
for phy in phys
|
||||
]
|
||||
max_ce_lengths = [
|
||||
int(connection_parameters_preferences[phy].max_ce_length / 0.625)
|
||||
for phy in phys
|
||||
]
|
||||
|
||||
await self.send_async_command(
|
||||
hci.HCI_LE_Extended_Create_Connection_Command(
|
||||
initiator_filter_policy=0,
|
||||
own_address_type=own_address_type,
|
||||
peer_address_type=peer_address.address_type,
|
||||
peer_address=peer_address,
|
||||
initiating_phys=initiating_phys,
|
||||
scan_intervals=(
|
||||
int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625),
|
||||
)
|
||||
* phy_count,
|
||||
scan_windows=(int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),)
|
||||
* phy_count,
|
||||
connection_interval_mins=connection_interval_mins,
|
||||
connection_interval_maxs=connection_interval_maxs,
|
||||
max_latencies=max_latencies,
|
||||
supervision_timeouts=supervision_timeouts,
|
||||
min_ce_lengths=min_ce_lengths,
|
||||
max_ce_lengths=max_ce_lengths,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if hci.HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||
raise InvalidArgumentError('1M PHY preferences required')
|
||||
|
||||
prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY]
|
||||
await self.send_async_command(
|
||||
hci.HCI_LE_Create_Connection_Command(
|
||||
le_scan_interval=int(
|
||||
DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625
|
||||
),
|
||||
le_scan_window=int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),
|
||||
initiator_filter_policy=0,
|
||||
peer_address_type=peer_address.address_type,
|
||||
peer_address=peer_address,
|
||||
own_address_type=own_address_type,
|
||||
connection_interval_min=int(
|
||||
prefs.connection_interval_min / 1.25
|
||||
),
|
||||
connection_interval_max=int(
|
||||
prefs.connection_interval_max / 1.25
|
||||
),
|
||||
max_latency=prefs.max_latency,
|
||||
supervision_timeout=int(prefs.supervision_timeout / 10),
|
||||
min_ce_length=int(prefs.min_ce_length / 0.625),
|
||||
max_ce_length=int(prefs.max_ce_length / 0.625),
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for the connection process to complete
|
||||
self.le_connecting = True
|
||||
|
||||
if timeout is None:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(pending_connection), timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self.send_sync_command(
|
||||
hci.HCI_LE_Create_Connection_Cancel_Command()
|
||||
)
|
||||
|
||||
try:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
except core.ConnectionError as error:
|
||||
raise core.TimeoutError() from error
|
||||
finally:
|
||||
self.remove_listener(self.EVENT_CONNECTION, on_connection)
|
||||
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
self.le_connecting = False
|
||||
self.connect_own_address_type = None
|
||||
|
||||
async def connect_classic(
|
||||
self,
|
||||
peer_address: hci.Address | str,
|
||||
timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
) -> Connection:
|
||||
if isinstance(peer_address, str):
|
||||
try:
|
||||
peer_address = hci.Address.from_string_for_transport(
|
||||
peer_address, PhysicalTransport.BR_EDR
|
||||
)
|
||||
except (InvalidArgumentError, ValueError):
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, PhysicalTransport.BR_EDR
|
||||
) # TODO: timeout
|
||||
else:
|
||||
# All BR/EDR addresses should be public addresses
|
||||
if peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS:
|
||||
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
|
||||
|
||||
assert isinstance(peer_address, hci.Address)
|
||||
|
||||
def on_connection(connection):
|
||||
if (
|
||||
# match BR/EDR connection event against peer address
|
||||
connection.transport == PhysicalTransport.BR_EDR
|
||||
and connection.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error: core.ConnectionError):
|
||||
if (
|
||||
# match BR/EDR connection failure event against peer address
|
||||
error.transport == PhysicalTransport.BR_EDR
|
||||
and error.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
self.on(self.EVENT_CONNECTION, on_connection)
|
||||
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
|
||||
try:
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection(
|
||||
device=self,
|
||||
handle=0,
|
||||
transport=core.PhysicalTransport.BR_EDR,
|
||||
self_address=self.public_address,
|
||||
self_resolvable_address=None,
|
||||
peer_address=peer_address,
|
||||
peer_resolvable_address=None,
|
||||
role=hci.Role.CENTRAL,
|
||||
parameters=Connection.Parameters(0, 0, 0),
|
||||
)
|
||||
|
||||
# TODO: allow passing other settings
|
||||
await self.send_async_command(
|
||||
hci.HCI_Create_Connection_Command(
|
||||
bd_addr=peer_address,
|
||||
packet_type=0xCC18, # FIXME: change
|
||||
page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE,
|
||||
clock_offset=0x0000,
|
||||
allow_role_switch=0x01,
|
||||
reserved=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for the connection process to complete
|
||||
if timeout is None:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(pending_connection), timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self.send_sync_command(
|
||||
hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
|
||||
)
|
||||
|
||||
try:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
except core.ConnectionError as error:
|
||||
raise core.TimeoutError() from error
|
||||
finally:
|
||||
self.remove_listener(self.EVENT_CONNECTION, on_connection)
|
||||
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
peer_address: hci.Address | str,
|
||||
@@ -3769,9 +4055,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
peer_address:
|
||||
hci.Address or name of the device to connect to.
|
||||
If a string is passed:
|
||||
If the string is an address followed by a `@` suffix, the `always_resolve`
|
||||
argument is implicitly set to True, so the connection is made to the
|
||||
address after resolution.
|
||||
[deprecated] If the string is an address followed by a `@` suffix, the
|
||||
`always_resolve`argument is implicitly set to True, so the connection is
|
||||
made to the address after resolution.
|
||||
If the string is any other address, the connection is made to that
|
||||
address (with or without address resolution, depending on the
|
||||
`always_resolve` argument).
|
||||
@@ -3795,271 +4081,29 @@ class Device(utils.CompositeEventEmitter):
|
||||
Pass None for an unlimited time.
|
||||
|
||||
always_resolve:
|
||||
(BLE only, ignored for BR/EDR)
|
||||
If True, always initiate a scan, resolving addresses, and connect to the
|
||||
address that resolves to `peer_address`.
|
||||
[deprecated] (ignore)
|
||||
'''
|
||||
|
||||
# Check parameters
|
||||
if transport not in (PhysicalTransport.LE, PhysicalTransport.BR_EDR):
|
||||
raise InvalidArgumentError('invalid transport')
|
||||
transport = core.PhysicalTransport(transport)
|
||||
# Connect using the appropriate transport
|
||||
# (auto-correct the transport based on declared capabilities)
|
||||
if transport == PhysicalTransport.LE or (
|
||||
self.le_enabled and not self.classic_enabled
|
||||
):
|
||||
return await self.connect_le(
|
||||
peer_address=peer_address,
|
||||
connection_parameters_preferences=connection_parameters_preferences,
|
||||
own_address_type=own_address_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Adjust the transport automatically if we need to
|
||||
if transport == PhysicalTransport.LE and not self.le_enabled:
|
||||
transport = PhysicalTransport.BR_EDR
|
||||
elif transport == PhysicalTransport.BR_EDR and not self.classic_enabled:
|
||||
transport = PhysicalTransport.LE
|
||||
if transport == PhysicalTransport.BR_EDR or (
|
||||
self.classic_enabled and not self.le_enabled
|
||||
):
|
||||
return await self.connect_classic(
|
||||
peer_address=peer_address, timeout=timeout
|
||||
)
|
||||
|
||||
# Check that there isn't already a pending connection
|
||||
if transport == PhysicalTransport.LE and self.is_le_connecting:
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
if isinstance(peer_address, str):
|
||||
try:
|
||||
if transport == PhysicalTransport.LE and peer_address.endswith('@'):
|
||||
peer_address = hci.Address.from_string_for_transport(
|
||||
peer_address[:-1], transport
|
||||
)
|
||||
always_resolve = True
|
||||
logger.debug('forcing address resolution')
|
||||
else:
|
||||
peer_address = hci.Address.from_string_for_transport(
|
||||
peer_address, transport
|
||||
)
|
||||
except (InvalidArgumentError, ValueError):
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
always_resolve = False
|
||||
logger.debug('looking for peer by name')
|
||||
assert isinstance(peer_address, str)
|
||||
peer_address = await self.find_peer_by_name(
|
||||
peer_address, transport
|
||||
) # TODO: timeout
|
||||
else:
|
||||
# All BR/EDR addresses should be public addresses
|
||||
if (
|
||||
transport == PhysicalTransport.BR_EDR
|
||||
and peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
):
|
||||
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
|
||||
|
||||
assert isinstance(peer_address, hci.Address)
|
||||
|
||||
if transport == PhysicalTransport.LE and always_resolve:
|
||||
logger.debug('resolving address')
|
||||
peer_address = await self.find_peer_by_identity_address(
|
||||
peer_address
|
||||
) # TODO: timeout
|
||||
|
||||
def on_connection(connection):
|
||||
if transport == PhysicalTransport.LE or (
|
||||
# match BR/EDR connection event against peer address
|
||||
connection.transport == transport
|
||||
and connection.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error: core.ConnectionError):
|
||||
if transport == PhysicalTransport.LE or (
|
||||
# match BR/EDR connection failure event against peer address
|
||||
error.transport == transport
|
||||
and error.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection's result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
self.on(self.EVENT_CONNECTION, on_connection)
|
||||
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
|
||||
try:
|
||||
# Tell the controller to connect
|
||||
if transport == PhysicalTransport.LE:
|
||||
if connection_parameters_preferences is None:
|
||||
if connection_parameters_preferences is None:
|
||||
connection_parameters_preferences = {
|
||||
hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default
|
||||
}
|
||||
|
||||
self.connect_own_address_type = own_address_type
|
||||
|
||||
if self.host.supports_command(
|
||||
hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND
|
||||
):
|
||||
# Only keep supported PHYs
|
||||
phys = sorted(
|
||||
list(
|
||||
set(
|
||||
filter(
|
||||
self.supports_le_phy,
|
||||
connection_parameters_preferences.keys(),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
if not phys:
|
||||
raise InvalidArgumentError('at least one supported PHY needed')
|
||||
|
||||
phy_count = len(phys)
|
||||
initiating_phys = hci.phy_list_to_bits(phys)
|
||||
|
||||
connection_interval_mins = [
|
||||
int(
|
||||
connection_parameters_preferences[
|
||||
phy
|
||||
].connection_interval_min
|
||||
/ 1.25
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
connection_interval_maxs = [
|
||||
int(
|
||||
connection_parameters_preferences[
|
||||
phy
|
||||
].connection_interval_max
|
||||
/ 1.25
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
max_latencies = [
|
||||
connection_parameters_preferences[phy].max_latency
|
||||
for phy in phys
|
||||
]
|
||||
supervision_timeouts = [
|
||||
int(
|
||||
connection_parameters_preferences[phy].supervision_timeout
|
||||
/ 10
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
min_ce_lengths = [
|
||||
int(
|
||||
connection_parameters_preferences[phy].min_ce_length / 0.625
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
max_ce_lengths = [
|
||||
int(
|
||||
connection_parameters_preferences[phy].max_ce_length / 0.625
|
||||
)
|
||||
for phy in phys
|
||||
]
|
||||
|
||||
await self.send_async_command(
|
||||
hci.HCI_LE_Extended_Create_Connection_Command(
|
||||
initiator_filter_policy=0,
|
||||
own_address_type=own_address_type,
|
||||
peer_address_type=peer_address.address_type,
|
||||
peer_address=peer_address,
|
||||
initiating_phys=initiating_phys,
|
||||
scan_intervals=(
|
||||
int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625),
|
||||
)
|
||||
* phy_count,
|
||||
scan_windows=(
|
||||
int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),
|
||||
)
|
||||
* phy_count,
|
||||
connection_interval_mins=connection_interval_mins,
|
||||
connection_interval_maxs=connection_interval_maxs,
|
||||
max_latencies=max_latencies,
|
||||
supervision_timeouts=supervision_timeouts,
|
||||
min_ce_lengths=min_ce_lengths,
|
||||
max_ce_lengths=max_ce_lengths,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if hci.HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||
raise InvalidArgumentError('1M PHY preferences required')
|
||||
|
||||
prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY]
|
||||
await self.send_async_command(
|
||||
hci.HCI_LE_Create_Connection_Command(
|
||||
le_scan_interval=int(
|
||||
DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625
|
||||
),
|
||||
le_scan_window=int(
|
||||
DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625
|
||||
),
|
||||
initiator_filter_policy=0,
|
||||
peer_address_type=peer_address.address_type,
|
||||
peer_address=peer_address,
|
||||
own_address_type=own_address_type,
|
||||
connection_interval_min=int(
|
||||
prefs.connection_interval_min / 1.25
|
||||
),
|
||||
connection_interval_max=int(
|
||||
prefs.connection_interval_max / 1.25
|
||||
),
|
||||
max_latency=prefs.max_latency,
|
||||
supervision_timeout=int(prefs.supervision_timeout / 10),
|
||||
min_ce_length=int(prefs.min_ce_length / 0.625),
|
||||
max_ce_length=int(prefs.max_ce_length / 0.625),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection(
|
||||
device=self,
|
||||
handle=0,
|
||||
transport=core.PhysicalTransport.BR_EDR,
|
||||
self_address=self.public_address,
|
||||
self_resolvable_address=None,
|
||||
peer_address=peer_address,
|
||||
peer_resolvable_address=None,
|
||||
role=hci.Role.CENTRAL,
|
||||
parameters=Connection.Parameters(0, 0, 0),
|
||||
)
|
||||
|
||||
# TODO: allow passing other settings
|
||||
await self.send_async_command(
|
||||
hci.HCI_Create_Connection_Command(
|
||||
bd_addr=peer_address,
|
||||
packet_type=0xCC18, # FIXME: change
|
||||
page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE,
|
||||
clock_offset=0x0000,
|
||||
allow_role_switch=0x01,
|
||||
reserved=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for the connection process to complete
|
||||
if transport == PhysicalTransport.LE:
|
||||
self.le_connecting = True
|
||||
|
||||
if timeout is None:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(pending_connection), timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if transport == PhysicalTransport.LE:
|
||||
await self.send_sync_command(
|
||||
hci.HCI_LE_Create_Connection_Cancel_Command()
|
||||
)
|
||||
else:
|
||||
await self.send_sync_command(
|
||||
hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
|
||||
)
|
||||
|
||||
try:
|
||||
return await utils.cancel_on_event(
|
||||
self, Device.EVENT_FLUSH, pending_connection
|
||||
)
|
||||
except core.ConnectionError as error:
|
||||
raise core.TimeoutError() from error
|
||||
finally:
|
||||
self.remove_listener(self.EVENT_CONNECTION, on_connection)
|
||||
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
|
||||
if transport == PhysicalTransport.LE:
|
||||
self.le_connecting = False
|
||||
self.connect_own_address_type = None
|
||||
else:
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
raise ValueError('invalid transport')
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
@@ -4706,6 +4750,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
Scan for a peer with a resolvable address that can be resolved to a given
|
||||
identity address.
|
||||
"""
|
||||
if self.address_resolver is None:
|
||||
raise InvalidStateError('no resolver')
|
||||
|
||||
# Create a future to wait for an address to be found
|
||||
peer_address = asyncio.get_running_loop().create_future()
|
||||
|
||||
@@ -201,50 +201,51 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
|
||||
value = data[2 : 2 + value_length]
|
||||
typed_value: Any
|
||||
|
||||
if value_type == ValueType.END:
|
||||
break
|
||||
match value_type:
|
||||
case ValueType.END:
|
||||
break
|
||||
|
||||
if value_type in (ValueType.CNVI, ValueType.CNVR):
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = (
|
||||
(((v >> 0) & 0xF) << 12)
|
||||
| (((v >> 4) & 0xF) << 0)
|
||||
| (((v >> 8) & 0xF) << 4)
|
||||
| (((v >> 24) & 0xF) << 8)
|
||||
)
|
||||
elif value_type == ValueType.HARDWARE_INFO:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = HardwareInfo(
|
||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||
)
|
||||
elif value_type in (
|
||||
ValueType.USB_VENDOR_ID,
|
||||
ValueType.USB_PRODUCT_ID,
|
||||
ValueType.DEVICE_REVISION,
|
||||
):
|
||||
(typed_value,) = struct.unpack("<H", value)
|
||||
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
|
||||
typed_value = ModeOfOperation(value[0])
|
||||
elif value_type in (
|
||||
ValueType.BUILD_TYPE,
|
||||
ValueType.BUILD_NUMBER,
|
||||
ValueType.SECURE_BOOT,
|
||||
ValueType.OTP_LOCK,
|
||||
ValueType.API_LOCK,
|
||||
ValueType.DEBUG_LOCK,
|
||||
ValueType.SECURE_BOOT_ENGINE_TYPE,
|
||||
):
|
||||
typed_value = value[0]
|
||||
elif value_type == ValueType.TIMESTAMP:
|
||||
typed_value = Timestamp(value[0], value[1])
|
||||
elif value_type == ValueType.FIRMWARE_BUILD:
|
||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||
elif value_type == ValueType.BLUETOOTH_ADDRESS:
|
||||
typed_value = hci.Address(
|
||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
else:
|
||||
typed_value = value
|
||||
case ValueType.CNVI | ValueType.CNVR:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = (
|
||||
(((v >> 0) & 0xF) << 12)
|
||||
| (((v >> 4) & 0xF) << 0)
|
||||
| (((v >> 8) & 0xF) << 4)
|
||||
| (((v >> 24) & 0xF) << 8)
|
||||
)
|
||||
case ValueType.HARDWARE_INFO:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = HardwareInfo(
|
||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||
)
|
||||
case (
|
||||
ValueType.USB_VENDOR_ID
|
||||
| ValueType.USB_PRODUCT_ID
|
||||
| ValueType.DEVICE_REVISION
|
||||
):
|
||||
(typed_value,) = struct.unpack("<H", value)
|
||||
case ValueType.CURRENT_MODE_OF_OPERATION:
|
||||
typed_value = ModeOfOperation(value[0])
|
||||
case (
|
||||
ValueType.BUILD_TYPE
|
||||
| ValueType.BUILD_NUMBER
|
||||
| ValueType.SECURE_BOOT
|
||||
| ValueType.OTP_LOCK
|
||||
| ValueType.API_LOCK
|
||||
| ValueType.DEBUG_LOCK
|
||||
| ValueType.SECURE_BOOT_ENGINE_TYPE
|
||||
):
|
||||
typed_value = value[0]
|
||||
case ValueType.TIMESTAMP:
|
||||
typed_value = Timestamp(value[0], value[1])
|
||||
case ValueType.FIRMWARE_BUILD:
|
||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||
case ValueType.BLUETOOTH_ADDRESS:
|
||||
typed_value = hci.Address(
|
||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
case _:
|
||||
typed_value = value
|
||||
|
||||
result.append((value_type, typed_value))
|
||||
data = data[2 + value_length :]
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from bumble.gatt import (
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
Characteristic,
|
||||
Service,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAccessService(Service):
|
||||
def __init__(self, device_name, appearance=(0, 0)):
|
||||
device_name_characteristic = Characteristic(
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
device_name.encode('utf-8')[:248],
|
||||
)
|
||||
|
||||
appearance_characteristic = Characteristic(
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
[device_name_characteristic, appearance_characteristic],
|
||||
)
|
||||
197
bumble/hci.py
197
bumble/hci.py
@@ -31,6 +31,7 @@ from typing import (
|
||||
ClassVar,
|
||||
Generic,
|
||||
Literal,
|
||||
SupportsBytes,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
@@ -247,28 +248,6 @@ HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
|
||||
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
|
||||
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
|
||||
|
||||
HCI_VERSION_NAMES = {
|
||||
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
|
||||
HCI_VERSION_BLUETOOTH_CORE_1_1: 'HCI_VERSION_BLUETOOTH_CORE_1_1',
|
||||
HCI_VERSION_BLUETOOTH_CORE_1_2: 'HCI_VERSION_BLUETOOTH_CORE_1_2',
|
||||
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_0_EDR',
|
||||
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR: 'HCI_VERSION_BLUETOOTH_CORE_2_1_EDR',
|
||||
HCI_VERSION_BLUETOOTH_CORE_3_0_HS: 'HCI_VERSION_BLUETOOTH_CORE_3_0_HS',
|
||||
HCI_VERSION_BLUETOOTH_CORE_4_0: 'HCI_VERSION_BLUETOOTH_CORE_4_0',
|
||||
HCI_VERSION_BLUETOOTH_CORE_4_1: 'HCI_VERSION_BLUETOOTH_CORE_4_1',
|
||||
HCI_VERSION_BLUETOOTH_CORE_4_2: 'HCI_VERSION_BLUETOOTH_CORE_4_2',
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0',
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1',
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2',
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
|
||||
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
|
||||
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
|
||||
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
|
||||
}
|
||||
|
||||
LMP_VERSION_NAMES = HCI_VERSION_NAMES
|
||||
|
||||
# HCI Packet types
|
||||
HCI_COMMAND_PACKET = 0x01
|
||||
HCI_ACL_DATA_PACKET = 0x02
|
||||
@@ -387,8 +366,8 @@ HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_V2_EVENT = 0X26
|
||||
HCI_LE_PERIODIC_ADVERTISING_SUBEVENT_DATA_REQUEST_EVENT = 0X27
|
||||
HCI_LE_PERIODIC_ADVERTISING_RESPONSE_REPORT_EVENT = 0X28
|
||||
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT = 0X29
|
||||
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2A
|
||||
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2B
|
||||
HCI_LE_CIS_ESTABLISHED_V2_EVENT = 0x2A
|
||||
HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT = 0x2B
|
||||
HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT = 0x2C
|
||||
HCI_LE_CS_READ_REMOTE_FAE_TABLE_COMPLETE_EVENT = 0x2D
|
||||
HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT = 0x2E
|
||||
@@ -1860,44 +1839,46 @@ class HCI_Object:
|
||||
field_type = field_type['parser']
|
||||
|
||||
# Parse the field
|
||||
if field_type == '*':
|
||||
# The rest of the bytes
|
||||
field_value = data[offset:]
|
||||
return (field_value, len(field_value))
|
||||
if field_type == 'v':
|
||||
# Variable-length bytes field, with 1-byte length at the beginning
|
||||
field_length = data[offset]
|
||||
offset += 1
|
||||
field_value = data[offset : offset + field_length]
|
||||
return (field_value, field_length + 1)
|
||||
if field_type == 1:
|
||||
# 8-bit unsigned
|
||||
return (data[offset], 1)
|
||||
if field_type == -1:
|
||||
# 8-bit signed
|
||||
return (struct.unpack_from('b', data, offset)[0], 1)
|
||||
if field_type == 2:
|
||||
# 16-bit unsigned
|
||||
return (struct.unpack_from('<H', data, offset)[0], 2)
|
||||
if field_type == '>2':
|
||||
# 16-bit unsigned big-endian
|
||||
return (struct.unpack_from('>H', data, offset)[0], 2)
|
||||
if field_type == -2:
|
||||
# 16-bit signed
|
||||
return (struct.unpack_from('<h', data, offset)[0], 2)
|
||||
if field_type == 3:
|
||||
# 24-bit unsigned
|
||||
padded = data[offset : offset + 3] + bytes([0])
|
||||
return (struct.unpack('<I', padded)[0], 3)
|
||||
if field_type == 4:
|
||||
# 32-bit unsigned
|
||||
return (struct.unpack_from('<I', data, offset)[0], 4)
|
||||
if field_type == '>4':
|
||||
# 32-bit unsigned big-endian
|
||||
return (struct.unpack_from('>I', data, offset)[0], 4)
|
||||
if isinstance(field_type, int) and 4 < field_type <= 256:
|
||||
# Byte array (from 5 up to 256 bytes)
|
||||
return (data[offset : offset + field_type], field_type)
|
||||
match field_type:
|
||||
case '*':
|
||||
# The rest of the bytes
|
||||
field_value = data[offset:]
|
||||
return (field_value, len(field_value))
|
||||
case 'v':
|
||||
# Variable-length bytes field, with 1-byte length at the beginning
|
||||
field_length = data[offset]
|
||||
offset += 1
|
||||
field_value = data[offset : offset + field_length]
|
||||
return (field_value, field_length + 1)
|
||||
case 1:
|
||||
# 8-bit unsigned
|
||||
return (data[offset], 1)
|
||||
case -1:
|
||||
# 8-bit signed
|
||||
return (struct.unpack_from('b', data, offset)[0], 1)
|
||||
case 2:
|
||||
# 16-bit unsigned
|
||||
return (struct.unpack_from('<H', data, offset)[0], 2)
|
||||
case '>2':
|
||||
# 16-bit unsigned big-endian
|
||||
return (struct.unpack_from('>H', data, offset)[0], 2)
|
||||
case -2:
|
||||
# 16-bit signed
|
||||
return (struct.unpack_from('<h', data, offset)[0], 2)
|
||||
case 3:
|
||||
# 24-bit unsigned
|
||||
padded = data[offset : offset + 3] + bytes([0])
|
||||
return (struct.unpack('<I', padded)[0], 3)
|
||||
case 4:
|
||||
# 32-bit unsigned
|
||||
return (struct.unpack_from('<I', data, offset)[0], 4)
|
||||
case '>4':
|
||||
# 32-bit unsigned big-endian
|
||||
return (struct.unpack_from('>I', data, offset)[0], 4)
|
||||
case int() if 4 < field_type <= 256:
|
||||
# Byte array (from 5 up to 256 bytes)
|
||||
return (data[offset : offset + field_type], field_type)
|
||||
|
||||
if callable(field_type):
|
||||
new_offset, field_value = field_type(data, offset)
|
||||
return (field_value, new_offset - offset)
|
||||
@@ -1954,60 +1935,58 @@ class HCI_Object:
|
||||
|
||||
# Serialize the field
|
||||
if serializer:
|
||||
field_bytes = serializer(field_value)
|
||||
elif field_type == 1:
|
||||
# 8-bit unsigned
|
||||
field_bytes = bytes([field_value])
|
||||
elif field_type == -1:
|
||||
# 8-bit signed
|
||||
field_bytes = struct.pack('b', field_value)
|
||||
elif field_type == 2:
|
||||
# 16-bit unsigned
|
||||
field_bytes = struct.pack('<H', field_value)
|
||||
elif field_type == '>2':
|
||||
# 16-bit unsigned big-endian
|
||||
field_bytes = struct.pack('>H', field_value)
|
||||
elif field_type == -2:
|
||||
# 16-bit signed
|
||||
field_bytes = struct.pack('<h', field_value)
|
||||
elif field_type == 3:
|
||||
# 24-bit unsigned
|
||||
field_bytes = struct.pack('<I', field_value)[0:3]
|
||||
elif field_type == 4:
|
||||
# 32-bit unsigned
|
||||
field_bytes = struct.pack('<I', field_value)
|
||||
elif field_type == '>4':
|
||||
# 32-bit unsigned big-endian
|
||||
field_bytes = struct.pack('>I', field_value)
|
||||
elif field_type == '*':
|
||||
if isinstance(field_value, int):
|
||||
if 0 <= field_value <= 255:
|
||||
field_bytes = bytes([field_value])
|
||||
return serializer(field_value)
|
||||
match field_type:
|
||||
case 1:
|
||||
# 8-bit unsigned
|
||||
return bytes([field_value])
|
||||
case -1:
|
||||
# 8-bit signed
|
||||
return struct.pack('b', field_value)
|
||||
case 2:
|
||||
# 16-bit unsigned
|
||||
return struct.pack('<H', field_value)
|
||||
case '>2':
|
||||
# 16-bit unsigned big-endian
|
||||
return struct.pack('>H', field_value)
|
||||
case -2:
|
||||
# 16-bit signed
|
||||
return struct.pack('<h', field_value)
|
||||
case 3:
|
||||
# 24-bit unsigned
|
||||
return struct.pack('<I', field_value)[0:3]
|
||||
case 4:
|
||||
# 32-bit unsigned
|
||||
return struct.pack('<I', field_value)
|
||||
case '>4':
|
||||
# 32-bit unsigned big-endian
|
||||
return struct.pack('>I', field_value)
|
||||
case '*':
|
||||
if isinstance(field_value, int):
|
||||
if 0 <= field_value <= 255:
|
||||
return bytes([field_value])
|
||||
else:
|
||||
raise InvalidArgumentError('value too large for *-typed field')
|
||||
else:
|
||||
raise InvalidArgumentError('value too large for *-typed field')
|
||||
else:
|
||||
return bytes(field_value)
|
||||
case 'v':
|
||||
# Variable-length bytes field, with 1-byte length at the beginning
|
||||
field_bytes = bytes(field_value)
|
||||
elif field_type == 'v':
|
||||
# Variable-length bytes field, with 1-byte length at the beginning
|
||||
field_bytes = bytes(field_value)
|
||||
field_length = len(field_bytes)
|
||||
field_bytes = bytes([field_length]) + field_bytes
|
||||
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
|
||||
field_value, '__bytes__'
|
||||
):
|
||||
field_length = len(field_bytes)
|
||||
return bytes([field_length]) + field_bytes
|
||||
if isinstance(field_value, (bytes, bytearray, SupportsBytes)):
|
||||
field_bytes = bytes(field_value)
|
||||
if isinstance(field_type, int) and 4 < field_type <= 256:
|
||||
# Truncate or pad with zeros if the field is too long or too short
|
||||
if len(field_bytes) < field_type:
|
||||
field_bytes += bytes(field_type - len(field_bytes))
|
||||
return field_bytes + bytes(field_type - len(field_bytes))
|
||||
elif len(field_bytes) > field_type:
|
||||
field_bytes = field_bytes[:field_type]
|
||||
else:
|
||||
raise InvalidArgumentError(
|
||||
f"don't know how to serialize type {type(field_value)}"
|
||||
)
|
||||
return field_bytes[:field_type]
|
||||
return field_bytes
|
||||
|
||||
return field_bytes
|
||||
raise InvalidArgumentError(
|
||||
f"don't know how to serialize type {type(field_value)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def dict_to_bytes(hci_object, object_fields):
|
||||
@@ -4736,7 +4715,7 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_SyncCommand[HCI_StatusReturnParame
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class HCI_LE_Read_Resolving_List_Size_ReturnParameters(HCI_StatusReturnParameters):
|
||||
resolving_list_size: bytes = field(metadata=metadata(1))
|
||||
resolving_list_size: int = field(metadata=metadata(1))
|
||||
|
||||
|
||||
@HCI_SyncCommand.sync_command(HCI_LE_Read_Resolving_List_Size_ReturnParameters)
|
||||
|
||||
149
bumble/hfp.py
149
bumble/hfp.py
@@ -26,7 +26,7 @@ import logging
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from typing import Any, ClassVar, Literal, overload
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -420,61 +420,6 @@ class CmeError(enum.IntEnum):
|
||||
# Hands-Free Control Interoperability Requirements
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Response codes.
|
||||
RESPONSE_CODES = {
|
||||
"+APLSIRI",
|
||||
"+BAC",
|
||||
"+BCC",
|
||||
"+BCS",
|
||||
"+BIA",
|
||||
"+BIEV",
|
||||
"+BIND",
|
||||
"+BINP",
|
||||
"+BLDN",
|
||||
"+BRSF",
|
||||
"+BTRH",
|
||||
"+BVRA",
|
||||
"+CCWA",
|
||||
"+CHLD",
|
||||
"+CHUP",
|
||||
"+CIND",
|
||||
"+CLCC",
|
||||
"+CLIP",
|
||||
"+CMEE",
|
||||
"+CMER",
|
||||
"+CNUM",
|
||||
"+COPS",
|
||||
"+IPHONEACCEV",
|
||||
"+NREC",
|
||||
"+VGM",
|
||||
"+VGS",
|
||||
"+VTS",
|
||||
"+XAPL",
|
||||
"A",
|
||||
"D",
|
||||
}
|
||||
|
||||
# Unsolicited responses and statuses.
|
||||
UNSOLICITED_CODES = {
|
||||
"+APLSIRI",
|
||||
"+BCS",
|
||||
"+BIND",
|
||||
"+BSIR",
|
||||
"+BTRH",
|
||||
"+BVRA",
|
||||
"+CCWA",
|
||||
"+CIEV",
|
||||
"+CLIP",
|
||||
"+VGM",
|
||||
"+VGS",
|
||||
"BLACKLISTED",
|
||||
"BUSY",
|
||||
"DELAYED",
|
||||
"NO ANSWER",
|
||||
"NO CARRIER",
|
||||
"RING",
|
||||
}
|
||||
|
||||
# Status codes
|
||||
STATUS_CODES = {
|
||||
"+CME ERROR",
|
||||
@@ -727,12 +672,9 @@ class HfProtocol(utils.EventEmitter):
|
||||
|
||||
dlc: rfcomm.DLC
|
||||
command_lock: asyncio.Lock
|
||||
if TYPE_CHECKING:
|
||||
response_queue: asyncio.Queue[AtResponse]
|
||||
unsolicited_queue: asyncio.Queue[AtResponse | None]
|
||||
else:
|
||||
response_queue: asyncio.Queue
|
||||
unsolicited_queue: asyncio.Queue
|
||||
pending_command: str | None = None
|
||||
response_queue: asyncio.Queue[AtResponse]
|
||||
unsolicited_queue: asyncio.Queue[AtResponse | None]
|
||||
read_buffer: bytearray
|
||||
active_codec: AudioCodec
|
||||
|
||||
@@ -805,16 +747,39 @@ class HfProtocol(utils.EventEmitter):
|
||||
self.read_buffer = self.read_buffer[trailer + 2 :]
|
||||
|
||||
# Forward the received code to the correct queue.
|
||||
if self.command_lock.locked() and (
|
||||
response.code in STATUS_CODES or response.code in RESPONSE_CODES
|
||||
if self.pending_command and (
|
||||
response.code in STATUS_CODES or response.code in self.pending_command
|
||||
):
|
||||
self.response_queue.put_nowait(response)
|
||||
elif response.code in UNSOLICITED_CODES:
|
||||
self.unsolicited_queue.put_nowait(response)
|
||||
else:
|
||||
logger.warning(
|
||||
f"dropping unexpected response with code '{response.code}'"
|
||||
)
|
||||
self.unsolicited_queue.put_nowait(response)
|
||||
|
||||
@overload
|
||||
async def execute_command(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: float = 1.0,
|
||||
*,
|
||||
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
async def execute_command(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: float = 1.0,
|
||||
*,
|
||||
response_type: Literal[AtResponseType.SINGLE],
|
||||
) -> AtResponse: ...
|
||||
|
||||
@overload
|
||||
async def execute_command(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: float = 1.0,
|
||||
*,
|
||||
response_type: Literal[AtResponseType.MULTIPLE],
|
||||
) -> list[AtResponse]: ...
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
@@ -835,27 +800,34 @@ class HfProtocol(utils.EventEmitter):
|
||||
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
|
||||
ProtocolError: the status is not OK.
|
||||
"""
|
||||
async with self.command_lock:
|
||||
logger.debug(f">>> {cmd}")
|
||||
self.dlc.write(cmd + '\r')
|
||||
responses: list[AtResponse] = []
|
||||
try:
|
||||
async with self.command_lock:
|
||||
self.pending_command = cmd
|
||||
logger.debug(f">>> {cmd}")
|
||||
self.dlc.write(cmd + '\r')
|
||||
responses: list[AtResponse] = []
|
||||
|
||||
while True:
|
||||
result = await asyncio.wait_for(
|
||||
self.response_queue.get(), timeout=timeout
|
||||
)
|
||||
if result.code == 'OK':
|
||||
if response_type == AtResponseType.SINGLE and len(responses) != 1:
|
||||
raise HfpProtocolError("NO ANSWER")
|
||||
while True:
|
||||
result = await asyncio.wait_for(
|
||||
self.response_queue.get(), timeout=timeout
|
||||
)
|
||||
if result.code == 'OK':
|
||||
if (
|
||||
response_type == AtResponseType.SINGLE
|
||||
and len(responses) != 1
|
||||
):
|
||||
raise HfpProtocolError("NO ANSWER")
|
||||
|
||||
if response_type == AtResponseType.MULTIPLE:
|
||||
return responses
|
||||
if response_type == AtResponseType.SINGLE:
|
||||
return responses[0]
|
||||
return None
|
||||
if result.code in STATUS_CODES:
|
||||
raise HfpProtocolError(result.code)
|
||||
responses.append(result)
|
||||
if response_type == AtResponseType.MULTIPLE:
|
||||
return responses
|
||||
if response_type == AtResponseType.SINGLE:
|
||||
return responses[0]
|
||||
return None
|
||||
if result.code in STATUS_CODES:
|
||||
raise HfpProtocolError(result.code)
|
||||
responses.append(result)
|
||||
finally:
|
||||
self.pending_command = None
|
||||
|
||||
async def initiate_slc(self):
|
||||
"""4.2.1 Service Level Connection Initialization."""
|
||||
@@ -1067,7 +1039,6 @@ class HfProtocol(utils.EventEmitter):
|
||||
responses = await self.execute_command(
|
||||
"AT+CLCC", response_type=AtResponseType.MULTIPLE
|
||||
)
|
||||
assert isinstance(responses, list)
|
||||
|
||||
calls = []
|
||||
for response in responses:
|
||||
|
||||
@@ -22,7 +22,7 @@ import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
from bumble import drivers, hci, utils
|
||||
from bumble.colors import color
|
||||
@@ -616,22 +616,28 @@ class Host(utils.EventEmitter):
|
||||
if self.supports_command(
|
||||
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
|
||||
):
|
||||
response10 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
|
||||
)
|
||||
self.number_of_supported_advertising_sets = (
|
||||
response10.num_supported_advertising_sets
|
||||
)
|
||||
try:
|
||||
response10 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
|
||||
)
|
||||
self.number_of_supported_advertising_sets = (
|
||||
response10.num_supported_advertising_sets
|
||||
)
|
||||
except hci.HCI_Error:
|
||||
logger.warning('Failed to read number of supported advertising sets')
|
||||
|
||||
if self.supports_command(
|
||||
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
|
||||
):
|
||||
response11 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
|
||||
)
|
||||
self.maximum_advertising_data_length = (
|
||||
response11.max_advertising_data_length
|
||||
)
|
||||
try:
|
||||
response11 = await self.send_sync_command(
|
||||
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
|
||||
)
|
||||
self.maximum_advertising_data_length = (
|
||||
response11.max_advertising_data_length
|
||||
)
|
||||
except hci.HCI_Error:
|
||||
logger.warning('Failed to read maximum advertising data length')
|
||||
|
||||
@property
|
||||
def controller(self) -> TransportSink | None:
|
||||
@@ -776,6 +782,20 @@ class Host(utils.EventEmitter):
|
||||
) -> hci.HCI_Command_Complete_Event[_RP]:
|
||||
response = await self._send_command(command, response_timeout)
|
||||
|
||||
# For unknown HCI commands, some controllers return Command Status instead of
|
||||
# Command Complete.
|
||||
if (
|
||||
isinstance(response, hci.HCI_Command_Status_Event)
|
||||
and response.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
|
||||
):
|
||||
return hci.HCI_Command_Complete_Event(
|
||||
num_hci_command_packets=response.num_hci_command_packets,
|
||||
command_opcode=command.op_code,
|
||||
return_parameters=hci.HCI_StatusReturnParameters(
|
||||
status=hci.HCI_ErrorCode(response.status)
|
||||
), # type: ignore
|
||||
)
|
||||
|
||||
# Check that the response is of the expected type
|
||||
assert isinstance(response, hci.HCI_Command_Complete_Event)
|
||||
|
||||
@@ -789,19 +809,25 @@ class Host(utils.EventEmitter):
|
||||
) -> hci.HCI_ErrorCode:
|
||||
response = await self._send_command(command, response_timeout)
|
||||
|
||||
# Check that the response is of the expected type
|
||||
assert isinstance(response, hci.HCI_Command_Status_Event)
|
||||
# For unknown HCI commands, some controllers return Command Complete instead of
|
||||
# Command Status.
|
||||
if isinstance(response, hci.HCI_Command_Complete_Event):
|
||||
# Assume the first byte of the return parameters is the status
|
||||
if (
|
||||
status := hci.HCI_ErrorCode(response.parameters[3])
|
||||
) != hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR:
|
||||
logger.warning(f'unexpected return paramerers status {status}')
|
||||
else:
|
||||
assert isinstance(response, hci.HCI_Command_Status_Event)
|
||||
status = hci.HCI_ErrorCode(response.status)
|
||||
|
||||
# Check the return parameters if required
|
||||
status = response.status
|
||||
# Check the status if required
|
||||
if check_status:
|
||||
if status != hci.HCI_CommandStatus.PENDING:
|
||||
logger.warning(
|
||||
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
|
||||
)
|
||||
logger.warning(f'{command.name} failed ' f'({status.name})')
|
||||
raise hci.HCI_Error(status)
|
||||
|
||||
return hci.HCI_ErrorCode(status)
|
||||
return status
|
||||
|
||||
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
|
||||
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
|
||||
@@ -830,7 +856,9 @@ class Host(utils.EventEmitter):
|
||||
data=pdu,
|
||||
)
|
||||
logger.debug(
|
||||
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu
|
||||
'>>> ACL packet enqueue: (handle=0x%04X) %s',
|
||||
connection_handle,
|
||||
pdu.hex(),
|
||||
)
|
||||
packet_queue.enqueue(acl_packet, connection_handle)
|
||||
|
||||
@@ -974,18 +1002,19 @@ class Host(utils.EventEmitter):
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
|
||||
# If the packet is a command, invoke the handler for this packet
|
||||
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.on_hci_command_packet(cast(hci.HCI_Command, packet))
|
||||
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
|
||||
self.on_hci_event_packet(cast(hci.HCI_Event, packet))
|
||||
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
|
||||
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
|
||||
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
|
||||
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
|
||||
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
|
||||
else:
|
||||
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
|
||||
match packet:
|
||||
case hci.HCI_Command():
|
||||
self.on_hci_command_packet(packet)
|
||||
case hci.HCI_Event():
|
||||
self.on_hci_event_packet(packet)
|
||||
case hci.HCI_AclDataPacket():
|
||||
self.on_hci_acl_data_packet(packet)
|
||||
case hci.HCI_SynchronousDataPacket():
|
||||
self.on_hci_sco_data_packet(packet)
|
||||
case hci.HCI_IsoDataPacket():
|
||||
self.on_hci_iso_data_packet(packet)
|
||||
case _:
|
||||
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
|
||||
|
||||
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
|
||||
logger.warning(f'!!! unexpected command packet: {command}')
|
||||
|
||||
21
bumble/ll.py
21
bumble/ll.py
@@ -198,3 +198,24 @@ class CisTerminateInd(ControlPdu):
|
||||
cig_id: int
|
||||
cis_id: int
|
||||
error_code: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FeatureReq(ControlPdu):
|
||||
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
|
||||
|
||||
feature_set: bytes
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FeatureRsp(ControlPdu):
|
||||
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
|
||||
|
||||
feature_set: bytes
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PeripheralFeatureReq(ControlPdu):
|
||||
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
|
||||
|
||||
feature_set: bytes
|
||||
|
||||
@@ -21,18 +21,9 @@ import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from bumble import hci
|
||||
from bumble import hci, smp
|
||||
from bumble.core import AdvertisingData, LeRole
|
||||
from bumble.smp import (
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
SMP_ENC_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
OobSharedData,
|
||||
@@ -96,11 +87,11 @@ class PairingDelegate:
|
||||
# These are defined abstractly, and can be mapped to specific Classic pairing
|
||||
# and/or SMP constants.
|
||||
class IoCapability(enum.IntEnum):
|
||||
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
|
||||
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
|
||||
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
|
||||
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
|
||||
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT
|
||||
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY
|
||||
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY
|
||||
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO
|
||||
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY
|
||||
|
||||
# Direct names for backward compatibility.
|
||||
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
|
||||
@@ -111,10 +102,10 @@ class PairingDelegate:
|
||||
|
||||
# Key Distribution [LE only]
|
||||
class KeyDistribution(enum.IntFlag):
|
||||
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY
|
||||
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY
|
||||
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY
|
||||
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY
|
||||
|
||||
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
|
||||
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
|
||||
|
||||
@@ -664,46 +664,44 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
responses = []
|
||||
logger.debug(f'*** ASCS Write {operation} ***')
|
||||
|
||||
if isinstance(operation, ASE_Config_Codec):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.target_latency,
|
||||
operation.target_phy,
|
||||
operation.codec_id,
|
||||
operation.codec_specific_configuration,
|
||||
match operation:
|
||||
case ASE_Config_Codec():
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.target_latency,
|
||||
operation.target_phy,
|
||||
operation.codec_id,
|
||||
operation.codec_specific_configuration,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
case ASE_Config_QOS():
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.cig_id,
|
||||
operation.cis_id,
|
||||
operation.sdu_interval,
|
||||
operation.framing,
|
||||
operation.phy,
|
||||
operation.max_sdu,
|
||||
operation.retransmission_number,
|
||||
operation.max_transport_latency,
|
||||
operation.presentation_delay,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
case ASE_Enable() | ASE_Update_Metadata():
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.metadata,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
case (
|
||||
ASE_Receiver_Start_Ready()
|
||||
| ASE_Disable()
|
||||
| ASE_Receiver_Stop_Ready()
|
||||
| ASE_Release()
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif isinstance(operation, ASE_Config_QOS):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.cig_id,
|
||||
operation.cis_id,
|
||||
operation.sdu_interval,
|
||||
operation.framing,
|
||||
operation.phy,
|
||||
operation.max_sdu,
|
||||
operation.retransmission_number,
|
||||
operation.max_transport_latency,
|
||||
operation.presentation_delay,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.metadata,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif isinstance(
|
||||
operation,
|
||||
(
|
||||
ASE_Receiver_Start_Ready,
|
||||
ASE_Disable,
|
||||
ASE_Receiver_Stop_Ready,
|
||||
ASE_Release,
|
||||
),
|
||||
):
|
||||
for ase_id in operation.ase_id:
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||
for ase_id in operation.ase_id:
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||
|
||||
control_point_notification = bytes(
|
||||
[operation.op_code, len(responses)]
|
||||
|
||||
@@ -333,17 +333,18 @@ class CodecSpecificCapabilities:
|
||||
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
||||
offset += length - 1
|
||||
|
||||
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
|
||||
supported_sampling_frequencies = SupportedSamplingFrequency(value)
|
||||
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
|
||||
supported_frame_durations = SupportedFrameDuration(value)
|
||||
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
||||
supported_audio_channel_count = bits_to_channel_counts(value)
|
||||
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
||||
min_octets_per_sample = value & 0xFFFF
|
||||
max_octets_per_sample = value >> 16
|
||||
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
|
||||
supported_max_codec_frames_per_sdu = value
|
||||
match type:
|
||||
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
|
||||
supported_sampling_frequencies = SupportedSamplingFrequency(value)
|
||||
case CodecSpecificCapabilities.Type.FRAME_DURATION:
|
||||
supported_frame_durations = SupportedFrameDuration(value)
|
||||
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
||||
supported_audio_channel_count = bits_to_channel_counts(value)
|
||||
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
||||
min_octets_per_sample = value & 0xFFFF
|
||||
max_octets_per_sample = value >> 16
|
||||
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
|
||||
supported_max_codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
|
||||
@@ -55,14 +55,15 @@ class GenericAccessService(TemplateService):
|
||||
def __init__(
|
||||
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
|
||||
):
|
||||
if isinstance(appearance, int):
|
||||
appearance_int = appearance
|
||||
elif isinstance(appearance, tuple):
|
||||
appearance_int = (appearance[0] << 6) | appearance[1]
|
||||
elif isinstance(appearance, Appearance):
|
||||
appearance_int = int(appearance)
|
||||
else:
|
||||
raise TypeError()
|
||||
match appearance:
|
||||
case int():
|
||||
appearance_int = appearance
|
||||
case tuple():
|
||||
appearance_int = (appearance[0] << 6) | appearance[1]
|
||||
case Appearance():
|
||||
appearance_int = int(appearance)
|
||||
case _:
|
||||
raise TypeError()
|
||||
|
||||
self.device_name_characteristic = Characteristic(
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
|
||||
521
bumble/smp.py
521
bumble/smp.py
@@ -27,18 +27,17 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
|
||||
|
||||
from bumble import crypto, utils
|
||||
from bumble import crypto, hci, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
InvalidArgumentError,
|
||||
PhysicalTransport,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
@@ -46,7 +45,6 @@ from bumble.hci import (
|
||||
HCI_LE_Enable_Encryption_Command,
|
||||
HCI_Object,
|
||||
Role,
|
||||
key_with_value,
|
||||
metadata,
|
||||
)
|
||||
from bumble.keys import PairingKeys
|
||||
@@ -71,110 +69,110 @@ logger = logging.getLogger(__name__)
|
||||
SMP_CID = 0x06
|
||||
SMP_BR_CID = 0x07
|
||||
|
||||
SMP_PAIRING_REQUEST_COMMAND = 0x01
|
||||
SMP_PAIRING_RESPONSE_COMMAND = 0x02
|
||||
SMP_PAIRING_CONFIRM_COMMAND = 0x03
|
||||
SMP_PAIRING_RANDOM_COMMAND = 0x04
|
||||
SMP_PAIRING_FAILED_COMMAND = 0x05
|
||||
SMP_ENCRYPTION_INFORMATION_COMMAND = 0x06
|
||||
SMP_MASTER_IDENTIFICATION_COMMAND = 0x07
|
||||
SMP_IDENTITY_INFORMATION_COMMAND = 0x08
|
||||
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND = 0x09
|
||||
SMP_SIGNING_INFORMATION_COMMAND = 0x0A
|
||||
SMP_SECURITY_REQUEST_COMMAND = 0x0B
|
||||
SMP_PAIRING_PUBLIC_KEY_COMMAND = 0x0C
|
||||
SMP_PAIRING_DHKEY_CHECK_COMMAND = 0x0D
|
||||
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND = 0x0E
|
||||
class CommandCode(hci.SpecableEnum):
|
||||
PAIRING_REQUEST = 0x01
|
||||
PAIRING_RESPONSE = 0x02
|
||||
PAIRING_CONFIRM = 0x03
|
||||
PAIRING_RANDOM = 0x04
|
||||
PAIRING_FAILED = 0x05
|
||||
ENCRYPTION_INFORMATION = 0x06
|
||||
MASTER_IDENTIFICATION = 0x07
|
||||
IDENTITY_INFORMATION = 0x08
|
||||
IDENTITY_ADDRESS_INFORMATION = 0x09
|
||||
SIGNING_INFORMATION = 0x0A
|
||||
SECURITY_REQUEST = 0x0B
|
||||
PAIRING_PUBLIC_KEY = 0x0C
|
||||
PAIRING_DHKEY_CHECK = 0x0D
|
||||
PAIRING_KEYPRESS_NOTIFICATION = 0x0E
|
||||
|
||||
SMP_COMMAND_NAMES = {
|
||||
SMP_PAIRING_REQUEST_COMMAND: 'SMP_PAIRING_REQUEST_COMMAND',
|
||||
SMP_PAIRING_RESPONSE_COMMAND: 'SMP_PAIRING_RESPONSE_COMMAND',
|
||||
SMP_PAIRING_CONFIRM_COMMAND: 'SMP_PAIRING_CONFIRM_COMMAND',
|
||||
SMP_PAIRING_RANDOM_COMMAND: 'SMP_PAIRING_RANDOM_COMMAND',
|
||||
SMP_PAIRING_FAILED_COMMAND: 'SMP_PAIRING_FAILED_COMMAND',
|
||||
SMP_ENCRYPTION_INFORMATION_COMMAND: 'SMP_ENCRYPTION_INFORMATION_COMMAND',
|
||||
SMP_MASTER_IDENTIFICATION_COMMAND: 'SMP_MASTER_IDENTIFICATION_COMMAND',
|
||||
SMP_IDENTITY_INFORMATION_COMMAND: 'SMP_IDENTITY_INFORMATION_COMMAND',
|
||||
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND: 'SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND',
|
||||
SMP_SIGNING_INFORMATION_COMMAND: 'SMP_SIGNING_INFORMATION_COMMAND',
|
||||
SMP_SECURITY_REQUEST_COMMAND: 'SMP_SECURITY_REQUEST_COMMAND',
|
||||
SMP_PAIRING_PUBLIC_KEY_COMMAND: 'SMP_PAIRING_PUBLIC_KEY_COMMAND',
|
||||
SMP_PAIRING_DHKEY_CHECK_COMMAND: 'SMP_PAIRING_DHKEY_CHECK_COMMAND',
|
||||
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND: 'SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND'
|
||||
}
|
||||
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY = 0x00
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY = 0x01
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY = 0x02
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = 0x03
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = 0x04
|
||||
class IoCapability(hci.SpecableEnum):
|
||||
DISPLAY_ONLY = 0x00
|
||||
DISPLAY_YES_NO = 0x01
|
||||
KEYBOARD_ONLY = 0x02
|
||||
NO_INPUT_NO_OUTPUT = 0x03
|
||||
KEYBOARD_DISPLAY = 0x04
|
||||
|
||||
SMP_IO_CAPABILITY_NAMES = {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: 'SMP_DISPLAY_ONLY_IO_CAPABILITY',
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: 'SMP_DISPLAY_YES_NO_IO_CAPABILITY',
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: 'SMP_KEYBOARD_ONLY_IO_CAPABILITY',
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: 'SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY',
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: 'SMP_KEYBOARD_DISPLAY_IO_CAPABILITY'
|
||||
}
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY = IoCapability.DISPLAY_ONLY
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY = IoCapability.DISPLAY_YES_NO
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY = IoCapability.KEYBOARD_ONLY
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = IoCapability.NO_INPUT_NO_OUTPUT
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = IoCapability.KEYBOARD_DISPLAY
|
||||
|
||||
SMP_PASSKEY_ENTRY_FAILED_ERROR = 0x01
|
||||
SMP_OOB_NOT_AVAILABLE_ERROR = 0x02
|
||||
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = 0x03
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR = 0x04
|
||||
SMP_PAIRING_NOT_SUPPORTED_ERROR = 0x05
|
||||
SMP_ENCRYPTION_KEY_SIZE_ERROR = 0x06
|
||||
SMP_COMMAND_NOT_SUPPORTED_ERROR = 0x07
|
||||
SMP_UNSPECIFIED_REASON_ERROR = 0x08
|
||||
SMP_REPEATED_ATTEMPTS_ERROR = 0x09
|
||||
SMP_INVALID_PARAMETERS_ERROR = 0x0A
|
||||
SMP_DHKEY_CHECK_FAILED_ERROR = 0x0B
|
||||
SMP_NUMERIC_COMPARISON_FAILED_ERROR = 0x0C
|
||||
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = 0x0D
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = 0x0E
|
||||
class ErrorCode(hci.SpecableEnum):
|
||||
PASSKEY_ENTRY_FAILED = 0x01
|
||||
OOB_NOT_AVAILABLE = 0x02
|
||||
AUTHENTICATION_REQUIREMENTS = 0x03
|
||||
CONFIRM_VALUE_FAILED = 0x04
|
||||
PAIRING_NOT_SUPPORTED = 0x05
|
||||
ENCRYPTION_KEY_SIZE = 0x06
|
||||
COMMAND_NOT_SUPPORTED = 0x07
|
||||
UNSPECIFIED_REASON = 0x08
|
||||
REPEATED_ATTEMPTS = 0x09
|
||||
INVALID_PARAMETERS = 0x0A
|
||||
DHKEY_CHECK_FAILED = 0x0B
|
||||
NUMERIC_COMPARISON_FAILED = 0x0C
|
||||
BD_EDR_PAIRING_IN_PROGRESS = 0x0D
|
||||
CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED = 0x0E
|
||||
|
||||
SMP_ERROR_NAMES = {
|
||||
SMP_PASSKEY_ENTRY_FAILED_ERROR: 'SMP_PASSKEY_ENTRY_FAILED_ERROR',
|
||||
SMP_OOB_NOT_AVAILABLE_ERROR: 'SMP_OOB_NOT_AVAILABLE_ERROR',
|
||||
SMP_AUTHENTICATION_REQUIREMENTS_ERROR: 'SMP_AUTHENTICATION_REQUIREMENTS_ERROR',
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR: 'SMP_CONFIRM_VALUE_FAILED_ERROR',
|
||||
SMP_PAIRING_NOT_SUPPORTED_ERROR: 'SMP_PAIRING_NOT_SUPPORTED_ERROR',
|
||||
SMP_ENCRYPTION_KEY_SIZE_ERROR: 'SMP_ENCRYPTION_KEY_SIZE_ERROR',
|
||||
SMP_COMMAND_NOT_SUPPORTED_ERROR: 'SMP_COMMAND_NOT_SUPPORTED_ERROR',
|
||||
SMP_UNSPECIFIED_REASON_ERROR: 'SMP_UNSPECIFIED_REASON_ERROR',
|
||||
SMP_REPEATED_ATTEMPTS_ERROR: 'SMP_REPEATED_ATTEMPTS_ERROR',
|
||||
SMP_INVALID_PARAMETERS_ERROR: 'SMP_INVALID_PARAMETERS_ERROR',
|
||||
SMP_DHKEY_CHECK_FAILED_ERROR: 'SMP_DHKEY_CHECK_FAILED_ERROR',
|
||||
SMP_NUMERIC_COMPARISON_FAILED_ERROR: 'SMP_NUMERIC_COMPARISON_FAILED_ERROR',
|
||||
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR: 'SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR',
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR: 'SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR'
|
||||
}
|
||||
SMP_PASSKEY_ENTRY_FAILED_ERROR = ErrorCode.PASSKEY_ENTRY_FAILED
|
||||
SMP_OOB_NOT_AVAILABLE_ERROR = ErrorCode.OOB_NOT_AVAILABLE
|
||||
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = ErrorCode.AUTHENTICATION_REQUIREMENTS
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR = ErrorCode.CONFIRM_VALUE_FAILED
|
||||
SMP_PAIRING_NOT_SUPPORTED_ERROR = ErrorCode.PAIRING_NOT_SUPPORTED
|
||||
SMP_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.ENCRYPTION_KEY_SIZE
|
||||
SMP_COMMAND_NOT_SUPPORTED_ERROR = ErrorCode.COMMAND_NOT_SUPPORTED
|
||||
SMP_UNSPECIFIED_REASON_ERROR = ErrorCode.UNSPECIFIED_REASON
|
||||
SMP_REPEATED_ATTEMPTS_ERROR = ErrorCode.REPEATED_ATTEMPTS
|
||||
SMP_INVALID_PARAMETERS_ERROR = ErrorCode.INVALID_PARAMETERS
|
||||
SMP_DHKEY_CHECK_FAILED_ERROR = ErrorCode.DHKEY_CHECK_FAILED
|
||||
SMP_NUMERIC_COMPARISON_FAILED_ERROR = ErrorCode.NUMERIC_COMPARISON_FAILED
|
||||
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = ErrorCode.BD_EDR_PAIRING_IN_PROGRESS
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
|
||||
|
||||
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE = 0
|
||||
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE = 1
|
||||
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE = 2
|
||||
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE = 3
|
||||
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE = 4
|
||||
|
||||
SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES = {
|
||||
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE',
|
||||
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE',
|
||||
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE',
|
||||
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE',
|
||||
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE'
|
||||
}
|
||||
class KeypressNotificationType(hci.SpecableEnum):
|
||||
PASSKEY_ENTRY_STARTED = 0
|
||||
PASSKEY_DIGIT_ENTERED = 1
|
||||
PASSKEY_DIGIT_ERASED = 2
|
||||
PASSKEY_CLEARED = 3
|
||||
PASSKEY_ENTRY_COMPLETED = 4
|
||||
|
||||
# Bit flags for key distribution/generation
|
||||
SMP_ENC_KEY_DISTRIBUTION_FLAG = 0b0001
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG = 0b0010
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG = 0b0100
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG = 0b1000
|
||||
class KeyDistribution(hci.SpecableFlag):
|
||||
ENC_KEY = 0b0001
|
||||
ID_KEY = 0b0010
|
||||
SIGN_KEY = 0b0100
|
||||
LINK_KEY = 0b1000
|
||||
|
||||
# AuthReq fields
|
||||
SMP_BONDING_AUTHREQ = 0b00000001
|
||||
SMP_MITM_AUTHREQ = 0b00000100
|
||||
SMP_SC_AUTHREQ = 0b00001000
|
||||
SMP_KEYPRESS_AUTHREQ = 0b00010000
|
||||
SMP_CT2_AUTHREQ = 0b00100000
|
||||
class AuthReq(hci.SpecableFlag):
|
||||
BONDING = 0b00000001
|
||||
MITM = 0b00000100
|
||||
SC = 0b00001000
|
||||
KEYPRESS = 0b00010000
|
||||
CT2 = 0b00100000
|
||||
|
||||
@classmethod
|
||||
def from_booleans(
|
||||
cls,
|
||||
bonding: bool = False,
|
||||
sc: bool = False,
|
||||
mitm: bool = False,
|
||||
keypress: bool = False,
|
||||
ct2: bool = False,
|
||||
) -> AuthReq:
|
||||
auth_req = AuthReq(0)
|
||||
if bonding:
|
||||
auth_req |= AuthReq.BONDING
|
||||
if sc:
|
||||
auth_req |= AuthReq.SC
|
||||
if mitm:
|
||||
auth_req |= AuthReq.MITM
|
||||
if keypress:
|
||||
auth_req |= AuthReq.KEYPRESS
|
||||
if ct2:
|
||||
auth_req |= AuthReq.CT2
|
||||
return auth_req
|
||||
|
||||
# Crypto salt
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
||||
@@ -188,8 +186,6 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
def error_name(error_code: int) -> str:
|
||||
return name_or_number(SMP_ERROR_NAMES, error_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -201,20 +197,20 @@ class SMP_Command:
|
||||
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
|
||||
'''
|
||||
|
||||
smp_classes: ClassVar[dict[int, type[SMP_Command]]] = {}
|
||||
smp_classes: ClassVar[dict[CommandCode, type[SMP_Command]]] = {}
|
||||
fields: ClassVar[Fields]
|
||||
code: int = field(default=0, init=False)
|
||||
code: CommandCode = field(default=CommandCode(0), init=False)
|
||||
name: str = field(default='', init=False)
|
||||
_payload: bytes | None = field(default=None, init=False)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, pdu: bytes) -> SMP_Command:
|
||||
code = pdu[0]
|
||||
code = CommandCode(pdu[0])
|
||||
|
||||
subclass = SMP_Command.smp_classes.get(code)
|
||||
if subclass is None:
|
||||
instance = SMP_Command()
|
||||
instance.name = SMP_Command.command_name(code)
|
||||
instance.name = code.name
|
||||
instance.code = code
|
||||
instance.payload = pdu
|
||||
return instance
|
||||
@@ -222,59 +218,14 @@ class SMP_Command:
|
||||
instance.payload = pdu[1:]
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def command_name(code: int) -> str:
|
||||
return name_or_number(SMP_COMMAND_NAMES, code)
|
||||
|
||||
@staticmethod
|
||||
def auth_req_str(value: int) -> str:
|
||||
bonding_flags = value & 3
|
||||
mitm = (value >> 2) & 1
|
||||
sc = (value >> 3) & 1
|
||||
keypress = (value >> 4) & 1
|
||||
ct2 = (value >> 5) & 1
|
||||
|
||||
return (
|
||||
f'bonding_flags={bonding_flags}, '
|
||||
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def io_capability_name(io_capability: int) -> str:
|
||||
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
|
||||
|
||||
@staticmethod
|
||||
def key_distribution_str(value: int) -> str:
|
||||
key_types: list[str] = []
|
||||
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('ENC')
|
||||
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('ID')
|
||||
if value & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('SIGN')
|
||||
if value & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('LINK')
|
||||
return ','.join(key_types)
|
||||
|
||||
@staticmethod
|
||||
def keypress_notification_type_name(notification_type: int) -> str:
|
||||
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
|
||||
|
||||
_Command = TypeVar("_Command", bound="SMP_Command")
|
||||
|
||||
@classmethod
|
||||
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
|
||||
subclass.name = subclass.__name__.upper()
|
||||
subclass.code = key_with_value(SMP_COMMAND_NAMES, subclass.name)
|
||||
if subclass.code is None:
|
||||
raise KeyError(
|
||||
f'Command name {subclass.name} not found in SMP_COMMAND_NAMES'
|
||||
)
|
||||
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
|
||||
|
||||
subclass.name = subclass.__name__.upper()
|
||||
# Register a factory for this class
|
||||
SMP_Command.smp_classes[subclass.code] = subclass
|
||||
|
||||
return subclass
|
||||
|
||||
@property
|
||||
@@ -308,19 +259,17 @@ class SMP_Pairing_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
|
||||
'''
|
||||
|
||||
io_capability: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
|
||||
)
|
||||
code = CommandCode.PAIRING_REQUEST
|
||||
|
||||
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
|
||||
oob_data_flag: int = field(metadata=metadata(1))
|
||||
auth_req: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
||||
)
|
||||
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
||||
initiator_key_distribution: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
||||
initiator_key_distribution: KeyDistribution = field(
|
||||
metadata=KeyDistribution.type_metadata(1)
|
||||
)
|
||||
responder_key_distribution: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
||||
responder_key_distribution: KeyDistribution = field(
|
||||
metadata=KeyDistribution.type_metadata(1)
|
||||
)
|
||||
|
||||
|
||||
@@ -332,19 +281,17 @@ class SMP_Pairing_Response_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
|
||||
'''
|
||||
|
||||
io_capability: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
|
||||
)
|
||||
code = CommandCode.PAIRING_RESPONSE
|
||||
|
||||
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
|
||||
oob_data_flag: int = field(metadata=metadata(1))
|
||||
auth_req: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
||||
)
|
||||
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||
maximum_encryption_key_size: int = field(metadata=metadata(1))
|
||||
initiator_key_distribution: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
||||
initiator_key_distribution: KeyDistribution = field(
|
||||
metadata=KeyDistribution.type_metadata(1)
|
||||
)
|
||||
responder_key_distribution: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
|
||||
responder_key_distribution: KeyDistribution = field(
|
||||
metadata=KeyDistribution.type_metadata(1)
|
||||
)
|
||||
|
||||
|
||||
@@ -356,6 +303,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
|
||||
'''
|
||||
|
||||
code = CommandCode.PAIRING_CONFIRM
|
||||
|
||||
confirm_value: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -367,6 +316,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
|
||||
'''
|
||||
|
||||
code = CommandCode.PAIRING_RANDOM
|
||||
|
||||
random_value: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -378,7 +329,9 @@ class SMP_Pairing_Failed_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
|
||||
'''
|
||||
|
||||
reason: int = field(metadata=metadata({'size': 1, 'mapper': error_name}))
|
||||
code = CommandCode.PAIRING_FAILED
|
||||
|
||||
reason: ErrorCode = field(metadata=ErrorCode.type_metadata(1))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -389,6 +342,8 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
|
||||
'''
|
||||
|
||||
code = CommandCode.PAIRING_PUBLIC_KEY
|
||||
|
||||
public_key_x: bytes = field(metadata=metadata(32))
|
||||
public_key_y: bytes = field(metadata=metadata(32))
|
||||
|
||||
@@ -401,6 +356,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
|
||||
'''
|
||||
|
||||
code = CommandCode.PAIRING_DHKEY_CHECK
|
||||
|
||||
dhkey_check: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -412,10 +369,10 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
|
||||
'''
|
||||
|
||||
notification_type: int = field(
|
||||
metadata=metadata(
|
||||
{'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}
|
||||
)
|
||||
code = CommandCode.PAIRING_KEYPRESS_NOTIFICATION
|
||||
|
||||
notification_type: KeypressNotificationType = field(
|
||||
metadata=KeypressNotificationType.type_metadata(1)
|
||||
)
|
||||
|
||||
|
||||
@@ -427,6 +384,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
|
||||
'''
|
||||
|
||||
code = CommandCode.ENCRYPTION_INFORMATION
|
||||
|
||||
long_term_key: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -438,6 +397,8 @@ class SMP_Master_Identification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
|
||||
'''
|
||||
|
||||
code = CommandCode.MASTER_IDENTIFICATION
|
||||
|
||||
ediv: int = field(metadata=metadata(2))
|
||||
rand: bytes = field(metadata=metadata(8))
|
||||
|
||||
@@ -450,6 +411,8 @@ class SMP_Identity_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
|
||||
'''
|
||||
|
||||
code = CommandCode.IDENTITY_INFORMATION
|
||||
|
||||
identity_resolving_key: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -461,6 +424,8 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
|
||||
'''
|
||||
|
||||
code = CommandCode.IDENTITY_ADDRESS_INFORMATION
|
||||
|
||||
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
|
||||
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
|
||||
|
||||
@@ -473,6 +438,8 @@ class SMP_Signing_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
|
||||
'''
|
||||
|
||||
code = CommandCode.SIGNING_INFORMATION
|
||||
|
||||
signature_key: bytes = field(metadata=metadata(16))
|
||||
|
||||
|
||||
@@ -484,33 +451,22 @@ class SMP_Security_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
|
||||
'''
|
||||
|
||||
auth_req: int = field(
|
||||
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
|
||||
)
|
||||
code = CommandCode.SECURITY_REQUEST
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
|
||||
value = 0
|
||||
if bonding:
|
||||
value |= SMP_BONDING_AUTHREQ
|
||||
if mitm:
|
||||
value |= SMP_MITM_AUTHREQ
|
||||
if sc:
|
||||
value |= SMP_SC_AUTHREQ
|
||||
if keypress:
|
||||
value |= SMP_KEYPRESS_AUTHREQ
|
||||
if ct2:
|
||||
value |= SMP_CT2_AUTHREQ
|
||||
return value
|
||||
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AddressResolver:
|
||||
def __init__(self, resolving_keys):
|
||||
def __init__(self, resolving_keys: Sequence[tuple[bytes, Address]]) -> None:
|
||||
self.resolving_keys = resolving_keys
|
||||
|
||||
def resolve(self, address):
|
||||
def can_resolve_to(self, address: Address) -> bool:
|
||||
return any(
|
||||
resolved_address == address for _, resolved_address in self.resolving_keys
|
||||
)
|
||||
|
||||
def resolve(self, address: Address) -> Address | None:
|
||||
address_bytes = bytes(address)
|
||||
hash_part = address_bytes[0:3]
|
||||
prand = address_bytes[3:6]
|
||||
@@ -671,8 +627,8 @@ class Session:
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key: bytes | None = None
|
||||
self.maximum_encryption_key_size: int = 0
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.initiator_key_distribution: KeyDistribution = KeyDistribution(0)
|
||||
self.responder_key_distribution: KeyDistribution = KeyDistribution(0)
|
||||
self.peer_random_value: bytes | None = None
|
||||
self.peer_public_key_x: bytes = bytes(32)
|
||||
self.peer_public_key_y = bytes(32)
|
||||
@@ -723,10 +679,10 @@ class Session:
|
||||
)
|
||||
|
||||
# Key Distribution (default values before negotiation)
|
||||
self.initiator_key_distribution = (
|
||||
self.initiator_key_distribution = KeyDistribution(
|
||||
pairing_config.delegate.local_initiator_key_distribution
|
||||
)
|
||||
self.responder_key_distribution = (
|
||||
self.responder_key_distribution = KeyDistribution(
|
||||
pairing_config.delegate.local_responder_key_distribution
|
||||
)
|
||||
|
||||
@@ -738,7 +694,7 @@ class Session:
|
||||
self.ct2: bool = False
|
||||
|
||||
# I/O Capabilities
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
self.io_capability = IoCapability(pairing_config.delegate.io_capability)
|
||||
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
|
||||
# OOB
|
||||
@@ -817,8 +773,14 @@ class Session:
|
||||
return self.nx[0 if self.is_responder else 1]
|
||||
|
||||
@property
|
||||
def auth_req(self) -> int:
|
||||
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
|
||||
def auth_req(self) -> AuthReq:
|
||||
return AuthReq.from_booleans(
|
||||
bonding=self.bonding,
|
||||
sc=self.sc,
|
||||
mitm=self.mitm,
|
||||
keypress=self.keypress,
|
||||
ct2=self.ct2,
|
||||
)
|
||||
|
||||
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
|
||||
if not self.sc and not self.completed:
|
||||
@@ -838,7 +800,7 @@ class Session:
|
||||
if self.connection.transport == PhysicalTransport.BR_EDR:
|
||||
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
||||
return
|
||||
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
|
||||
if (not self.mitm) and (auth_req & AuthReq.MITM == 0):
|
||||
self.pairing_method = PairingMethod.JUST_WORKS
|
||||
return
|
||||
|
||||
@@ -856,7 +818,7 @@ class Session:
|
||||
self.passkey_display = details[1 if self.is_initiator else 2]
|
||||
|
||||
def check_expected_value(
|
||||
self, expected: bytes, received: bytes, error: int
|
||||
self, expected: bytes, received: bytes, error: ErrorCode
|
||||
) -> bool:
|
||||
logger.debug(f'expected={expected.hex()} got={received.hex()}')
|
||||
if expected != received:
|
||||
@@ -876,7 +838,7 @@ class Session:
|
||||
except Exception:
|
||||
logger.exception('exception while confirm')
|
||||
|
||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
|
||||
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
@@ -895,7 +857,7 @@ class Session:
|
||||
except Exception:
|
||||
logger.exception('exception while prompting')
|
||||
|
||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
|
||||
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
@@ -906,13 +868,13 @@ class Session:
|
||||
passkey = await self.pairing_config.delegate.get_number()
|
||||
if passkey is None:
|
||||
logger.debug('Passkey request rejected')
|
||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
|
||||
return
|
||||
logger.debug(f'user input: {passkey}')
|
||||
next_steps(passkey)
|
||||
except Exception:
|
||||
logger.exception('exception while prompting')
|
||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
|
||||
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
@@ -967,7 +929,7 @@ class Session:
|
||||
def send_command(self, command: SMP_Command) -> None:
|
||||
self.manager.send_command(self.connection, command)
|
||||
|
||||
def send_pairing_failed(self, error: int) -> None:
|
||||
def send_pairing_failed(self, error: ErrorCode) -> None:
|
||||
self.send_command(SMP_Pairing_Failed_Command(reason=error))
|
||||
self.on_pairing_failure(error)
|
||||
|
||||
@@ -1139,7 +1101,7 @@ class Session:
|
||||
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
||||
)
|
||||
self.send_pairing_failed(
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
|
||||
ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
|
||||
)
|
||||
else:
|
||||
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
||||
@@ -1150,14 +1112,14 @@ class Session:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
if (
|
||||
self.connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
and self.initiator_key_distribution & KeyDistribution.ENC_KEY
|
||||
):
|
||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||
self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
elif not self.sc:
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||
if self.initiator_key_distribution & KeyDistribution.ENC_KEY:
|
||||
self.send_command(
|
||||
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
||||
)
|
||||
@@ -1168,7 +1130,7 @@ class Session:
|
||||
)
|
||||
|
||||
# Distribute IRK & BD ADDR
|
||||
if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
||||
if self.initiator_key_distribution & KeyDistribution.ID_KEY:
|
||||
self.send_command(
|
||||
SMP_Identity_Information_Command(
|
||||
identity_resolving_key=self.manager.device.irk
|
||||
@@ -1178,25 +1140,25 @@ class Session:
|
||||
|
||||
# Distribute CSRK
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
||||
if self.initiator_key_distribution & KeyDistribution.SIGN_KEY:
|
||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
if self.initiator_key_distribution & KeyDistribution.LINK_KEY:
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
else:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
if (
|
||||
self.connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
and self.responder_key_distribution & KeyDistribution.ENC_KEY
|
||||
):
|
||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||
self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
elif not self.sc:
|
||||
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||
if self.responder_key_distribution & KeyDistribution.ENC_KEY:
|
||||
self.send_command(
|
||||
SMP_Encryption_Information_Command(long_term_key=self.ltk)
|
||||
)
|
||||
@@ -1207,7 +1169,7 @@ class Session:
|
||||
)
|
||||
|
||||
# Distribute IRK & BD ADDR
|
||||
if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
||||
if self.responder_key_distribution & KeyDistribution.ID_KEY:
|
||||
self.send_command(
|
||||
SMP_Identity_Information_Command(
|
||||
identity_resolving_key=self.manager.device.irk
|
||||
@@ -1217,30 +1179,30 @@ class Session:
|
||||
|
||||
# Distribute CSRK
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
||||
if self.responder_key_distribution & KeyDistribution.SIGN_KEY:
|
||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
if self.responder_key_distribution & KeyDistribution.LINK_KEY:
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
self.peer_expected_distributions = []
|
||||
if not self.sc and self.connection.transport == PhysicalTransport.LE:
|
||||
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
|
||||
if key_distribution_flags & KeyDistribution.ENC_KEY != 0:
|
||||
self.peer_expected_distributions.append(
|
||||
SMP_Encryption_Information_Command
|
||||
)
|
||||
self.peer_expected_distributions.append(
|
||||
SMP_Master_Identification_Command
|
||||
)
|
||||
if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0:
|
||||
if key_distribution_flags & KeyDistribution.ID_KEY != 0:
|
||||
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
|
||||
self.peer_expected_distributions.append(
|
||||
SMP_Identity_Address_Information_Command
|
||||
)
|
||||
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
|
||||
if key_distribution_flags & KeyDistribution.SIGN_KEY != 0:
|
||||
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
|
||||
logger.debug(
|
||||
'expecting distributions: '
|
||||
@@ -1253,7 +1215,7 @@ class Session:
|
||||
logger.warning(
|
||||
color('received key distribution on a non-encrypted connection', 'red')
|
||||
)
|
||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
|
||||
return
|
||||
|
||||
# Check that this command class is expected
|
||||
@@ -1273,7 +1235,7 @@ class Session:
|
||||
'red',
|
||||
)
|
||||
)
|
||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
|
||||
|
||||
async def pair(self) -> None:
|
||||
# Start pairing as an initiator
|
||||
@@ -1384,34 +1346,56 @@ class Session:
|
||||
)
|
||||
await self.manager.on_pairing(self, peer_address, keys)
|
||||
|
||||
def on_pairing_failure(self, reason: int) -> None:
|
||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
||||
def on_pairing_failure(self, reason: ErrorCode) -> None:
|
||||
logger.warning('pairing failure (%s)', reason.name)
|
||||
|
||||
if self.completed:
|
||||
return
|
||||
|
||||
self.completed = True
|
||||
|
||||
error = ProtocolError(reason, 'smp', error_name(reason))
|
||||
error = ProtocolError(reason, 'smp', reason.name)
|
||||
if self.pairing_result is not None and not self.pairing_result.done():
|
||||
self.pairing_result.set_exception(error)
|
||||
self.manager.on_pairing_failure(self, reason)
|
||||
|
||||
def on_smp_command(self, command: SMP_Command) -> None:
|
||||
# Find the handler method
|
||||
handler_name = f'on_{command.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler is not None:
|
||||
try:
|
||||
handler(command)
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
response = SMP_Pairing_Failed_Command(
|
||||
reason=SMP_UNSPECIFIED_REASON_ERROR
|
||||
)
|
||||
self.send_command(response)
|
||||
else:
|
||||
logger.error(color('SMP command not handled???', 'red'))
|
||||
try:
|
||||
match command:
|
||||
case SMP_Pairing_Request_Command():
|
||||
self.on_smp_pairing_request_command(command)
|
||||
case SMP_Pairing_Response_Command():
|
||||
self.on_smp_pairing_response_command(command)
|
||||
case SMP_Pairing_Confirm_Command():
|
||||
self.on_smp_pairing_confirm_command(command)
|
||||
case SMP_Pairing_Random_Command():
|
||||
self.on_smp_pairing_random_command(command)
|
||||
case SMP_Pairing_Failed_Command():
|
||||
self.on_smp_pairing_failed_command(command)
|
||||
case SMP_Encryption_Information_Command():
|
||||
self.on_smp_encryption_information_command(command)
|
||||
case SMP_Master_Identification_Command():
|
||||
self.on_smp_master_identification_command(command)
|
||||
case SMP_Identity_Information_Command():
|
||||
self.on_smp_identity_information_command(command)
|
||||
case SMP_Identity_Address_Information_Command():
|
||||
self.on_smp_identity_address_information_command(command)
|
||||
case SMP_Signing_Information_Command():
|
||||
self.on_smp_signing_information_command(command)
|
||||
case SMP_Pairing_Public_Key_Command():
|
||||
self.on_smp_pairing_public_key_command(command)
|
||||
case SMP_Pairing_DHKey_Check_Command():
|
||||
self.on_smp_pairing_dhkey_check_command(command)
|
||||
# case SMP_Security_Request_Command():
|
||||
# self.on_smp_security_request_command(command)
|
||||
# case SMP_Pairing_Keypress_Notification_Command():
|
||||
# self.on_smp_pairing_keypress_notification_command(command)
|
||||
case _:
|
||||
logger.error(color('SMP command not handled', 'red'))
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
response = SMP_Pairing_Failed_Command(reason=ErrorCode.UNSPECIFIED_REASON)
|
||||
self.send_command(response)
|
||||
|
||||
def on_smp_pairing_request_command(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
@@ -1431,16 +1415,16 @@ class Session:
|
||||
accepted = False
|
||||
if not accepted:
|
||||
logger.debug('pairing rejected by delegate')
|
||||
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.PAIRING_NOT_SUPPORTED)
|
||||
return
|
||||
|
||||
# Save the request
|
||||
self.preq = bytes(command)
|
||||
|
||||
# Bonding and SC require both sides to request/support it
|
||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
|
||||
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
|
||||
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
|
||||
self.ct2 = self.ct2 and (command.auth_req & AuthReq.CT2 != 0)
|
||||
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
@@ -1451,7 +1435,7 @@ class Session:
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
@@ -1470,8 +1454,11 @@ class Session:
|
||||
(
|
||||
self.initiator_key_distribution,
|
||||
self.responder_key_distribution,
|
||||
) = await self.pairing_config.delegate.key_distribution_response(
|
||||
command.initiator_key_distribution, command.responder_key_distribution
|
||||
) = map(
|
||||
KeyDistribution,
|
||||
await self.pairing_config.delegate.key_distribution_response(
|
||||
command.initiator_key_distribution, command.responder_key_distribution
|
||||
),
|
||||
)
|
||||
self.compute_peer_expected_distributions(self.initiator_key_distribution)
|
||||
|
||||
@@ -1509,8 +1496,8 @@ class Session:
|
||||
self.peer_io_capability = command.io_capability
|
||||
|
||||
# Bonding and SC require both sides to request/support it
|
||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
|
||||
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
|
||||
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
@@ -1521,7 +1508,7 @@ class Session:
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
@@ -1541,7 +1528,7 @@ class Session:
|
||||
command.responder_key_distribution & ~self.responder_key_distribution != 0
|
||||
):
|
||||
# The response isn't a subset of the request
|
||||
self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR)
|
||||
self.send_pairing_failed(ErrorCode.INVALID_PARAMETERS)
|
||||
return
|
||||
self.initiator_key_distribution = command.initiator_key_distribution
|
||||
self.responder_key_distribution = command.responder_key_distribution
|
||||
@@ -1619,7 +1606,7 @@ class Session:
|
||||
)
|
||||
assert self.confirm_value
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1660,7 +1647,7 @@ class Session:
|
||||
self.pkb, self.pka, command.random_value, bytes([0])
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||
):
|
||||
return
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
@@ -1673,7 +1660,7 @@ class Session:
|
||||
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1702,7 +1689,7 @@ class Session:
|
||||
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1819,7 +1806,7 @@ class Session:
|
||||
if not self.check_expected_value(
|
||||
self.peer_oob_data.c,
|
||||
confirm_verifier,
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
ErrorCode.CONFIRM_VALUE_FAILED,
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1853,7 +1840,7 @@ class Session:
|
||||
expected = self.eb if self.is_initiator else self.ea
|
||||
assert expected
|
||||
if not self.check_expected_value(
|
||||
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
|
||||
expected, command.dhkey_check, ErrorCode.DHKEY_CHECK_FAILED
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1957,7 +1944,7 @@ class Manager(utils.EventEmitter):
|
||||
)
|
||||
|
||||
# Security request is more than just pairing, so let applications handle them
|
||||
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||
if command.code == CommandCode.SECURITY_REQUEST:
|
||||
self.on_smp_security_request_command(
|
||||
connection, cast(SMP_Security_Request_Command, command)
|
||||
)
|
||||
@@ -1997,15 +1984,13 @@ class Manager(utils.EventEmitter):
|
||||
def request_pairing(self, connection: Connection) -> None:
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config:
|
||||
auth_req = smp_auth_req(
|
||||
pairing_config.bonding,
|
||||
pairing_config.mitm,
|
||||
pairing_config.sc,
|
||||
False,
|
||||
False,
|
||||
auth_req = AuthReq.from_booleans(
|
||||
bonding=pairing_config.bonding,
|
||||
sc=pairing_config.sc,
|
||||
mitm=pairing_config.mitm,
|
||||
)
|
||||
else:
|
||||
auth_req = 0
|
||||
auth_req = AuthReq(0)
|
||||
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
|
||||
|
||||
def on_session_start(self, session: Session) -> None:
|
||||
@@ -2021,7 +2006,7 @@ class Manager(utils.EventEmitter):
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session: Session, reason: int) -> None:
|
||||
def on_pairing_failure(self, session: Session, reason: ErrorCode) -> None:
|
||||
self.device.on_pairing_failure(session.connection, reason)
|
||||
|
||||
def on_session_end(self, session: Session) -> None:
|
||||
|
||||
@@ -25,7 +25,7 @@ import sys
|
||||
import websockets.asyncio.server
|
||||
|
||||
import bumble.logging
|
||||
from bumble import a2dp, avc, avdtp, avrcp, utils
|
||||
from bumble import a2dp, avc, avdtp, avrcp, sdp, utils
|
||||
from bumble.core import PhysicalTransport
|
||||
from bumble.device import Device
|
||||
from bumble.transport import open_transport
|
||||
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def sdp_records():
|
||||
def sdp_records() -> dict[int, list[sdp.ServiceAttribute]]:
|
||||
a2dp_sink_service_record_handle = 0x00010001
|
||||
avrcp_controller_service_record_handle = 0x00010002
|
||||
avrcp_target_service_record_handle = 0x00010003
|
||||
@@ -43,17 +43,17 @@ def sdp_records():
|
||||
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
|
||||
a2dp_sink_service_record_handle
|
||||
),
|
||||
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records(
|
||||
avrcp_controller_service_record_handle: avrcp.ControllerServiceSdpRecord(
|
||||
avrcp_controller_service_record_handle
|
||||
),
|
||||
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records(
|
||||
avrcp_controller_service_record_handle
|
||||
),
|
||||
).to_service_attributes(),
|
||||
avrcp_target_service_record_handle: avrcp.TargetServiceSdpRecord(
|
||||
avrcp_target_service_record_handle
|
||||
).to_service_attributes(),
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def codec_capabilities():
|
||||
def codec_capabilities() -> avdtp.MediaCodecCapabilities:
|
||||
return avdtp.MediaCodecCapabilities(
|
||||
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
|
||||
@@ -81,20 +81,22 @@ def codec_capabilities():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_avdtp_connection(server):
|
||||
def on_avdtp_connection(server: avdtp.Protocol) -> None:
|
||||
# Add a sink endpoint to the server
|
||||
sink = server.add_sink(codec_capabilities())
|
||||
sink.on('rtp_packet', on_rtp_packet)
|
||||
sink.on(sink.EVENT_RTP_PACKET, on_rtp_packet)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_rtp_packet(packet):
|
||||
def on_rtp_packet(packet: avdtp.MediaPacket) -> None:
|
||||
print(f'RTP: {packet}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer):
|
||||
async def get_supported_events():
|
||||
def on_avrcp_start(
|
||||
avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer
|
||||
) -> None:
|
||||
async def get_supported_events() -> None:
|
||||
events = await avrcp_protocol.get_supported_events()
|
||||
print("SUPPORTED EVENTS:", events)
|
||||
websocket_server.send_message(
|
||||
@@ -130,14 +132,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
|
||||
utils.AsyncRunner.spawn(get_supported_events())
|
||||
|
||||
async def monitor_track_changed():
|
||||
async for identifier in avrcp_protocol.monitor_track_changed():
|
||||
print("TRACK CHANGED:", identifier.hex())
|
||||
async def monitor_track_changed() -> None:
|
||||
async for uid in avrcp_protocol.monitor_track_changed():
|
||||
print("TRACK CHANGED:", hex(uid))
|
||||
websocket_server.send_message(
|
||||
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
|
||||
{"type": "track-changed", "params": {"identifier": hex(uid)}}
|
||||
)
|
||||
|
||||
async def monitor_playback_status():
|
||||
async def monitor_playback_status() -> None:
|
||||
async for playback_status in avrcp_protocol.monitor_playback_status():
|
||||
print("PLAYBACK STATUS CHANGED:", playback_status.name)
|
||||
websocket_server.send_message(
|
||||
@@ -147,7 +149,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
}
|
||||
)
|
||||
|
||||
async def monitor_playback_position():
|
||||
async def monitor_playback_position() -> None:
|
||||
async for playback_position in avrcp_protocol.monitor_playback_position(
|
||||
playback_interval=1
|
||||
):
|
||||
@@ -159,7 +161,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
}
|
||||
)
|
||||
|
||||
async def monitor_player_application_settings():
|
||||
async def monitor_player_application_settings() -> None:
|
||||
async for settings in avrcp_protocol.monitor_player_application_settings():
|
||||
print("PLAYER APPLICATION SETTINGS:", settings)
|
||||
settings_as_dict = [
|
||||
@@ -173,14 +175,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
}
|
||||
)
|
||||
|
||||
async def monitor_available_players():
|
||||
async def monitor_available_players() -> None:
|
||||
async for _ in avrcp_protocol.monitor_available_players():
|
||||
print("AVAILABLE PLAYERS CHANGED")
|
||||
websocket_server.send_message(
|
||||
{"type": "available-players-changed", "params": {}}
|
||||
)
|
||||
|
||||
async def monitor_addressed_player():
|
||||
async def monitor_addressed_player() -> None:
|
||||
async for player in avrcp_protocol.monitor_addressed_player():
|
||||
print("ADDRESSED PLAYER CHANGED")
|
||||
websocket_server.send_message(
|
||||
@@ -195,7 +197,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
}
|
||||
)
|
||||
|
||||
async def monitor_uids():
|
||||
async def monitor_uids() -> None:
|
||||
async for uid_counter in avrcp_protocol.monitor_uids():
|
||||
print("UIDS CHANGED")
|
||||
websocket_server.send_message(
|
||||
@@ -207,7 +209,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
|
||||
}
|
||||
)
|
||||
|
||||
async def monitor_volume():
|
||||
async def monitor_volume() -> None:
|
||||
async for volume in avrcp_protocol.monitor_volume():
|
||||
print("VOLUME CHANGED:", volume)
|
||||
websocket_server.send_message(
|
||||
@@ -360,7 +362,7 @@ async def main() -> None:
|
||||
|
||||
# Create a listener to wait for AVDTP connections
|
||||
listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
|
||||
listener.on('connection', on_avdtp_connection)
|
||||
listener.on(listener.EVENT_CONNECTION, on_avdtp_connection)
|
||||
|
||||
avrcp_delegate = Delegate()
|
||||
avrcp_protocol = avrcp.Protocol(avrcp_delegate)
|
||||
|
||||
@@ -37,7 +37,7 @@ dependencies = [
|
||||
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
|
||||
"pyserial >= 3.5; platform_system!='Emscripten'",
|
||||
"pyusb >= 1.2; platform_system!='Emscripten'",
|
||||
"tomli ~= 2.2.1; platform_system!='Emscripten'",
|
||||
"tomli ~= 2.2.1; platform_system!='Emscripten' and python_version<'3.11'",
|
||||
"websockets >= 15.0.1; platform_system!='Emscripten'",
|
||||
]
|
||||
|
||||
|
||||
4
rust/Cargo.lock
generated
4
rust/Cargo.lock
generated
@@ -221,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.5.0"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
||||
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
|
||||
@@ -30,7 +30,7 @@ hex = "0.4.3"
|
||||
itertools = "0.11.0"
|
||||
lazy_static = "1.4.0"
|
||||
thiserror = "1.0.41"
|
||||
bytes = "1.5.0"
|
||||
bytes = "1.11.1"
|
||||
pdl-derive = "0.2.0"
|
||||
pdl-runtime = "0.2.0"
|
||||
futures = "0.3.28"
|
||||
|
||||
@@ -17,8 +17,10 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
from collections.abc import Sequence
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -117,8 +119,6 @@ class TwoDevices(test_utils.TwoDevices):
|
||||
scope=avrcp.Scope.NOW_PLAYING,
|
||||
uid=0,
|
||||
uid_counter=1,
|
||||
start_item=0,
|
||||
end_item=0,
|
||||
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
|
||||
),
|
||||
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
|
||||
@@ -135,7 +135,7 @@ def test_command(command: avrcp.Command):
|
||||
"event,",
|
||||
[
|
||||
avrcp.UidsChangedEvent(uid_counter=7),
|
||||
avrcp.TrackChangedEvent(identifier=b'12356'),
|
||||
avrcp.TrackChangedEvent(uid=12356),
|
||||
avrcp.VolumeChangedEvent(volume=9),
|
||||
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
|
||||
avrcp.AddressedPlayerChangedEvent(
|
||||
@@ -422,6 +422,47 @@ def test_passthrough_commands():
|
||||
assert bytes(parsed) == play_pressed_bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_sdp_records():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
# Add SDP records to device 1
|
||||
controller_record = avrcp.ControllerServiceSdpRecord(
|
||||
service_record_handle=0x10001,
|
||||
avctp_version=(1, 4),
|
||||
avrcp_version=(1, 6),
|
||||
supported_features=(
|
||||
avrcp.ControllerFeatures.CATEGORY_1
|
||||
| avrcp.ControllerFeatures.SUPPORTS_BROWSING
|
||||
),
|
||||
)
|
||||
target_record = avrcp.TargetServiceSdpRecord(
|
||||
service_record_handle=0x10002,
|
||||
avctp_version=(1, 4),
|
||||
avrcp_version=(1, 6),
|
||||
supported_features=(
|
||||
avrcp.TargetFeatures.CATEGORY_1 | avrcp.TargetFeatures.SUPPORTS_BROWSING
|
||||
),
|
||||
)
|
||||
|
||||
two_devices.devices[1].sdp_service_records = {
|
||||
0x10001: controller_record.to_service_attributes(),
|
||||
0x10002: target_record.to_service_attributes(),
|
||||
}
|
||||
|
||||
# Find records from device 0
|
||||
controller_records = await avrcp.ControllerServiceSdpRecord.find(
|
||||
two_devices.connections[0]
|
||||
)
|
||||
assert len(controller_records) == 1
|
||||
assert controller_records[0] == controller_record
|
||||
|
||||
target_records = await avrcp.TargetServiceSdpRecord.find(two_devices.connections[0])
|
||||
assert len(target_records) == 1
|
||||
assert target_records[0] == target_record
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_supported_events():
|
||||
@@ -436,6 +477,340 @@ async def test_get_supported_events():
|
||||
assert supported_events == [avrcp.EventId.VOLUME_CHANGED]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_passthrough_key_event():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
q = asyncio.Queue[tuple[avc.PassThroughFrame.OperationId, bool, bytes]]()
|
||||
|
||||
class Delegate(avrcp.Delegate):
|
||||
async def on_key_event(
|
||||
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
|
||||
) -> None:
|
||||
q.put_nowait((key, pressed, data))
|
||||
|
||||
two_devices.protocols[1].delegate = Delegate()
|
||||
|
||||
for key, pressed in [
|
||||
(avc.PassThroughFrame.OperationId.PLAY, True),
|
||||
(avc.PassThroughFrame.OperationId.PLAY, False),
|
||||
(avc.PassThroughFrame.OperationId.PAUSE, True),
|
||||
(avc.PassThroughFrame.OperationId.PAUSE, False),
|
||||
]:
|
||||
await two_devices.protocols[0].send_key_event(key, pressed)
|
||||
assert (await q.get()) == (key, pressed, b'')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_passthrough_key_event_rejected():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
class Delegate(avrcp.Delegate):
|
||||
async def on_key_event(
|
||||
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
|
||||
) -> None:
|
||||
raise avrcp.Delegate.AvcError(avc.ResponseFrame.ResponseCode.REJECTED)
|
||||
|
||||
two_devices.protocols[1].delegate = Delegate()
|
||||
|
||||
response = await two_devices.protocols[0].send_key_event(
|
||||
avc.PassThroughFrame.OperationId.PLAY, True
|
||||
)
|
||||
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_passthrough_key_event_exception():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
class Delegate(avrcp.Delegate):
|
||||
async def on_key_event(
|
||||
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
|
||||
) -> None:
|
||||
raise Exception()
|
||||
|
||||
two_devices.protocols[1].delegate = Delegate()
|
||||
|
||||
response = await two_devices.protocols[0].send_key_event(
|
||||
avc.PassThroughFrame.OperationId.PLAY, True
|
||||
)
|
||||
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_volume():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
|
||||
response = await two_devices.protocols[1].send_avrcp_command(
|
||||
avc.CommandFrame.CommandType.CONTROL, avrcp.SetAbsoluteVolumeCommand(volume)
|
||||
)
|
||||
assert isinstance(response.response, avrcp.SetAbsoluteVolumeResponse)
|
||||
assert response.response.volume == volume
|
||||
assert two_devices.protocols[0].delegate.volume == volume
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_playback_status():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
for status in avrcp.PlayStatus:
|
||||
two_devices.protocols[0].delegate.playback_status = status
|
||||
response = await two_devices.protocols[1].get_play_status()
|
||||
assert response.play_status == status
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_supported_company_ids():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
for status in avrcp.PlayStatus:
|
||||
two_devices.protocols[0].delegate = avrcp.Delegate(
|
||||
supported_company_ids=[avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
|
||||
)
|
||||
supported_company_ids = await two_devices.protocols[
|
||||
1
|
||||
].get_supported_company_ids()
|
||||
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_player_application_settings():
|
||||
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
expected_settings = {
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: [
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.SINGLE_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.OFF,
|
||||
],
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: [
|
||||
avrcp.ApplicationSetting.ShuffleOnOffStatus.OFF,
|
||||
avrcp.ApplicationSetting.ShuffleOnOffStatus.ALL_TRACKS_SHUFFLE,
|
||||
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||
],
|
||||
}
|
||||
two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
supported_player_app_settings=expected_settings
|
||||
)
|
||||
actual_settings = await two_devices.protocols[
|
||||
0
|
||||
].list_supported_player_app_settings()
|
||||
assert actual_settings == expected_settings
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_set_player_app_settings():
|
||||
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate
|
||||
await two_devices.protocols[0].send_avrcp_command(
|
||||
avc.CommandFrame.CommandType.CONTROL,
|
||||
avrcp.SetPlayerApplicationSettingValueCommand(
|
||||
attribute=[
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||
],
|
||||
value=[
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||
],
|
||||
),
|
||||
)
|
||||
expected_settings = {
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
|
||||
}
|
||||
assert delegate.player_app_settings == expected_settings
|
||||
|
||||
actual_settings = await two_devices.protocols[0].get_player_app_settings(
|
||||
[
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
|
||||
]
|
||||
)
|
||||
assert actual_settings == expected_settings
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_item():
|
||||
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate
|
||||
|
||||
with mock.patch.object(delegate, delegate.play_item.__name__) as play_item_mock:
|
||||
await two_devices.protocols[0].send_avrcp_command(
|
||||
avc.CommandFrame.CommandType.CONTROL,
|
||||
avrcp.PlayItemCommand(
|
||||
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
|
||||
),
|
||||
)
|
||||
|
||||
play_item_mock.assert_called_once_with(
|
||||
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_volume():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
two_devices.protocols[1].delegate = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
|
||||
volume_iter = two_devices.protocols[0].monitor_volume()
|
||||
|
||||
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
|
||||
# Interim
|
||||
two_devices.protocols[1].delegate.volume = 0
|
||||
assert (await anext(volume_iter)) == 0
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_volume_changed(volume)
|
||||
assert (await anext(volume_iter)) == volume
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_playback_status():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
[avrcp.EventId.PLAYBACK_STATUS_CHANGED]
|
||||
)
|
||||
playback_status_iter = two_devices.protocols[0].monitor_playback_status()
|
||||
|
||||
for playback_status in avrcp.PlayStatus:
|
||||
# Interim
|
||||
two_devices.protocols[1].delegate.playback_status = avrcp.PlayStatus.STOPPED
|
||||
assert (await anext(playback_status_iter)) == avrcp.PlayStatus.STOPPED
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_playback_status_changed(playback_status)
|
||||
assert (await anext(playback_status_iter)) == playback_status
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_now_playing_content():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
[avrcp.EventId.NOW_PLAYING_CONTENT_CHANGED]
|
||||
)
|
||||
now_playing_iter = two_devices.protocols[0].monitor_now_playing_content()
|
||||
|
||||
for _ in range(2):
|
||||
# Interim
|
||||
await anext(now_playing_iter)
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_now_playing_content_changed()
|
||||
await anext(now_playing_iter)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_track_changed():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
[avrcp.EventId.TRACK_CHANGED]
|
||||
)
|
||||
delegate.current_track_uid = avrcp.TrackChangedEvent.NO_TRACK
|
||||
track_iter = two_devices.protocols[0].monitor_track_changed()
|
||||
|
||||
# Interim
|
||||
assert (await anext(track_iter)) == avrcp.TrackChangedEvent.NO_TRACK
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_track_changed(1)
|
||||
assert (await anext(track_iter)) == 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_uid_changed():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
[avrcp.EventId.UIDS_CHANGED]
|
||||
)
|
||||
delegate.uid_counter = 0
|
||||
uid_iter = two_devices.protocols[0].monitor_uids()
|
||||
|
||||
# Interim
|
||||
assert (await anext(uid_iter)) == 0
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_uids_changed(1)
|
||||
assert (await anext(uid_iter)) == 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_addressed_player():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
[avrcp.EventId.ADDRESSED_PLAYER_CHANGED]
|
||||
)
|
||||
delegate.uid_counter = 0
|
||||
delegate.addressed_player_id = 0
|
||||
addressed_player_iter = two_devices.protocols[0].monitor_addressed_player()
|
||||
|
||||
# Interim
|
||||
assert (
|
||||
await anext(addressed_player_iter)
|
||||
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=0, uid_counter=0)
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_addressed_player_changed(
|
||||
avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
|
||||
)
|
||||
assert (
|
||||
await anext(addressed_player_iter)
|
||||
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_player_app_settings():
|
||||
two_devices = await TwoDevices.create_with_avdtp()
|
||||
|
||||
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
|
||||
supported_events=[avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED]
|
||||
)
|
||||
delegate.player_app_settings = {
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
|
||||
}
|
||||
settings_iter = two_devices.protocols[0].monitor_player_application_settings()
|
||||
|
||||
# Interim
|
||||
interim = await anext(settings_iter)
|
||||
assert interim[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
|
||||
assert (
|
||||
interim[0].value_id
|
||||
== avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
|
||||
)
|
||||
|
||||
# Changed
|
||||
two_devices.protocols[1].notify_player_application_settings_changed(
|
||||
[
|
||||
avrcp.PlayerApplicationSettingChangedEvent.Setting(
|
||||
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
|
||||
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
|
||||
)
|
||||
]
|
||||
)
|
||||
changed = await anext(settings_iter)
|
||||
assert changed[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
|
||||
assert changed[0].value_id == avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_frame_parser()
|
||||
|
||||
@@ -309,6 +309,27 @@ async def test_legacy_advertising_disconnection(auto_restart):
|
||||
assert not devices[0].is_advertising
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_le_multiple_connects():
|
||||
devices = TwoDevices()
|
||||
for controller in devices.controllers:
|
||||
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
|
||||
for dev in devices:
|
||||
await dev.power_on()
|
||||
await devices[0].start_advertising(auto_restart=True, advertising_interval_min=1.0)
|
||||
|
||||
connection = await devices[1].connect(devices[0].random_address)
|
||||
await connection.disconnect()
|
||||
|
||||
await async_barrier()
|
||||
await async_barrier()
|
||||
|
||||
# a second connection attempt is working
|
||||
connection = await devices[1].connect(devices[0].random_address)
|
||||
await connection.disconnect()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_advertising_and_scanning():
|
||||
@@ -445,7 +466,9 @@ async def test_get_remote_le_features():
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
assert (await devices.connections[0].get_remote_le_features()) is not None
|
||||
assert (
|
||||
await devices.connections[0].get_remote_le_features()
|
||||
) == devices.controllers[1].le_features
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -232,6 +232,14 @@ def test_return_parameters() -> None:
|
||||
assert len(params.local_name) == 248
|
||||
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
|
||||
|
||||
# Some return parameters may be shorter than the full length
|
||||
# (for Command Complete events with errors)
|
||||
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
|
||||
bytes.fromhex('010011223344')
|
||||
)
|
||||
assert isinstance(params, hci.HCI_StatusReturnParameters)
|
||||
assert params.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_HCI_Command():
|
||||
|
||||
@@ -26,11 +26,14 @@ from bumble.controller import Controller
|
||||
from bumble.hci import (
|
||||
HCI_AclDataPacket,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_CommandStatus,
|
||||
HCI_Disconnect_Command,
|
||||
HCI_Error,
|
||||
HCI_ErrorCode,
|
||||
HCI_Event,
|
||||
HCI_GenericReturnParameters,
|
||||
HCI_LE_Terminate_BIG_Command,
|
||||
HCI_Reset_Command,
|
||||
HCI_StatusReturnParameters,
|
||||
)
|
||||
@@ -229,3 +232,47 @@ async def test_send_sync_command() -> None:
|
||||
)
|
||||
response3 = await host.send_sync_command_raw(command) # type: ignore
|
||||
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_async_command() -> None:
|
||||
source = Source()
|
||||
sink = Sink(
|
||||
source,
|
||||
HCI_Command_Status_Event(
|
||||
HCI_CommandStatus.PENDING,
|
||||
1,
|
||||
HCI_Reset_Command.op_code,
|
||||
),
|
||||
)
|
||||
|
||||
host = Host(source, sink)
|
||||
host.ready = True
|
||||
|
||||
# Normal pending status
|
||||
response = await host.send_async_command(
|
||||
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0)
|
||||
)
|
||||
assert response == HCI_CommandStatus.PENDING
|
||||
|
||||
# Unknown HCI command result returned as a Command Status
|
||||
sink.response = HCI_Command_Status_Event(
|
||||
HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR,
|
||||
1,
|
||||
HCI_LE_Terminate_BIG_Command.op_code,
|
||||
)
|
||||
response = await host.send_async_command(
|
||||
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
|
||||
)
|
||||
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
|
||||
|
||||
# Unknown HCI command result returned as a Command Complete
|
||||
sink.response = HCI_Command_Complete_Event(
|
||||
1,
|
||||
HCI_LE_Terminate_BIG_Command.op_code,
|
||||
HCI_StatusReturnParameters(HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR),
|
||||
)
|
||||
response = await host.send_async_command(
|
||||
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
|
||||
)
|
||||
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
|
||||
|
||||
@@ -29,8 +29,7 @@ from bumble.gatt import Characteristic, Service
|
||||
from bumble.hci import Role
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
from bumble.smp import (
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
SMP_PAIRING_NOT_SUPPORTED_ERROR,
|
||||
ErrorCode,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
)
|
||||
@@ -378,7 +377,7 @@ async def test_self_smp_reject():
|
||||
await _test_self_smp_with_configs(None, rejecting_pairing_config)
|
||||
paired = True
|
||||
except ProtocolError as error:
|
||||
assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
|
||||
assert error.error_code == ErrorCode.PAIRING_NOT_SUPPORTED
|
||||
|
||||
assert not paired
|
||||
|
||||
@@ -403,7 +402,7 @@ async def test_self_smp_wrong_pin():
|
||||
)
|
||||
paired = True
|
||||
except ProtocolError as error:
|
||||
assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
assert error.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||
|
||||
assert not paired
|
||||
|
||||
@@ -534,11 +533,11 @@ async def test_self_smp_oob_sc():
|
||||
|
||||
with pytest.raises(ProtocolError) as error:
|
||||
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
|
||||
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||
|
||||
with pytest.raises(ProtocolError):
|
||||
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
|
||||
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user