forked from auracaster/bumble_mirror
Compare commits
32 Commits
packageFil
...
gbg/consol
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e4055bb6b | ||
|
|
4433184048 | ||
|
|
312fc8db36 | ||
|
|
615691ec81 | ||
|
|
ae8b83f294 | ||
|
|
4a8e21f4db | ||
|
|
3462e7c437 | ||
|
|
0f2e5239ad | ||
|
|
ee48cdc63f | ||
|
|
1c278bec93 | ||
|
|
6a51166af7 | ||
|
|
85d79fa914 | ||
|
|
142bdce94a | ||
|
|
881a5a64b5 | ||
|
|
5aae44b610 | ||
|
|
e3ea167827 | ||
|
|
eec145e095 | ||
|
|
87fa02d6e5 | ||
|
|
ad94c1e1f3 | ||
|
|
546a0bce8d | ||
|
|
cb7ca44a1c | ||
|
|
4081b93407 | ||
|
|
26203ebaad | ||
|
|
3389e3e1ed | ||
|
|
7e1f01c01e | ||
|
|
613e15548a | ||
|
|
e09c91df8e | ||
|
|
df206667b6 | ||
|
|
0f19dd5263 | ||
|
|
b98e4937f3 | ||
|
|
27791cf218 | ||
|
|
f8a2d4f0e0 |
30
.devcontainer/devcontainer.json
Normal file
30
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,30 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
||||
{
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/universal:2",
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
"postCreateCommand":
|
||||
"python -m pip install '.[build,test,development,documentation]'",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
"customizations": {
|
||||
// Configure properties specific to VS Code.
|
||||
"vscode": {
|
||||
// Add the IDs of extensions you want installed when the container is created.
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_Constant,
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_LE_2M_PHY,
|
||||
@@ -167,6 +168,7 @@ class ConsoleApp:
|
||||
'remote-services': None,
|
||||
'local-values': None,
|
||||
'remote-values': None,
|
||||
'remote-attributes': None,
|
||||
},
|
||||
'filter': {
|
||||
'address': None,
|
||||
@@ -215,6 +217,7 @@ class ConsoleApp:
|
||||
)
|
||||
self.local_values_text = FormattedTextControl()
|
||||
self.remote_values_text = FormattedTextControl()
|
||||
self.remote_attributes_text = FormattedTextControl()
|
||||
self.log_height = Dimension(min=7, weight=4)
|
||||
self.log_max_lines = 100
|
||||
self.log_lines = []
|
||||
@@ -241,6 +244,12 @@ class ConsoleApp:
|
||||
Frame(Window(self.remote_values_text), title='Remote Values'),
|
||||
filter=Condition(lambda: self.top_tab == 'remote-values'),
|
||||
),
|
||||
ConditionalContainer(
|
||||
Frame(
|
||||
Window(self.remote_attributes_text), title='Remote Attributes'
|
||||
),
|
||||
filter=Condition(lambda: self.top_tab == 'remote-attributes'),
|
||||
),
|
||||
ConditionalContainer(
|
||||
Frame(Window(self.log_text, height=self.log_height), title='Log'),
|
||||
filter=Condition(lambda: self.top_tab == 'log'),
|
||||
@@ -289,11 +298,7 @@ class ConsoleApp:
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
random_address = (
|
||||
f"{random.randint(192,255):02X}" # address is static random
|
||||
)
|
||||
for random_byte in random.sample(range(255), 5):
|
||||
random_address += f":{random_byte:02X}"
|
||||
random_address = Address.generate_static_address()
|
||||
self.append_to_log(f"Setting random address: {random_address}")
|
||||
self.device = Device.with_hci(
|
||||
'Bumble', random_address, hci_source, hci_sink
|
||||
@@ -503,19 +508,9 @@ class ConsoleApp:
|
||||
self.show_error('not connected')
|
||||
return
|
||||
|
||||
# Discover all services, characteristics and descriptors
|
||||
self.append_to_output('discovering services...')
|
||||
await self.connected_peer.discover_services()
|
||||
self.append_to_output(
|
||||
f'found {len(self.connected_peer.services)} services,'
|
||||
' discovering characteristics...'
|
||||
)
|
||||
await self.connected_peer.discover_characteristics()
|
||||
self.append_to_output('found characteristics, discovering descriptors...')
|
||||
for service in self.connected_peer.services:
|
||||
for characteristic in service.characteristics:
|
||||
await self.connected_peer.discover_descriptors(characteristic)
|
||||
self.append_to_output('discovery completed')
|
||||
self.append_to_output('Service Discovery starting...')
|
||||
await self.connected_peer.discover_all()
|
||||
self.append_to_output('Service Discovery done!')
|
||||
|
||||
self.show_remote_services(self.connected_peer.services)
|
||||
|
||||
@@ -529,7 +524,7 @@ class ConsoleApp:
|
||||
attributes = await self.connected_peer.discover_attributes()
|
||||
self.append_to_output(f'discovered {len(attributes)} attributes...')
|
||||
|
||||
self.show_attributes(attributes)
|
||||
await self.show_remote_attributes(attributes)
|
||||
|
||||
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
|
||||
if not self.connected_peer:
|
||||
@@ -674,7 +669,6 @@ class ConsoleApp:
|
||||
connection_parameters_preferences=connection_parameters_preferences,
|
||||
timeout=DEFAULT_CONNECTION_TIMEOUT,
|
||||
)
|
||||
self.top_tab = 'services'
|
||||
except bumble.core.TimeoutError:
|
||||
self.show_error('connection timed out')
|
||||
|
||||
@@ -745,19 +739,20 @@ class ConsoleApp:
|
||||
'remote-services',
|
||||
'local-values',
|
||||
'remote-values',
|
||||
'remote-attributes',
|
||||
}:
|
||||
self.top_tab = params[0]
|
||||
self.ui.invalidate()
|
||||
|
||||
while self.top_tab == 'local-values':
|
||||
await self.do_show_local_values()
|
||||
await self.show_local_values()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
while self.top_tab == 'remote-values':
|
||||
await self.do_show_remote_values()
|
||||
await self.show_remote_values()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def do_show_local_values(self):
|
||||
async def show_local_values(self):
|
||||
prettytable = PrettyTable()
|
||||
field_names = ["Service", "Characteristic", "Descriptor"]
|
||||
|
||||
@@ -812,7 +807,7 @@ class ConsoleApp:
|
||||
self.local_values_text.text = prettytable.get_string()
|
||||
self.ui.invalidate()
|
||||
|
||||
async def do_show_remote_values(self):
|
||||
async def show_remote_values(self):
|
||||
prettytable = PrettyTable(
|
||||
field_names=[
|
||||
"Connection",
|
||||
@@ -846,6 +841,23 @@ class ConsoleApp:
|
||||
self.remote_values_text.text = prettytable.get_string()
|
||||
self.ui.invalidate()
|
||||
|
||||
async def show_remote_attributes(self, attributes):
|
||||
lines = []
|
||||
for attribute in attributes:
|
||||
lines.append(('ansimagenta', str(attribute) + "\n"))
|
||||
try:
|
||||
value = await attribute.read_value()
|
||||
lines.append(('ansicyan', value.hex() + "\n"))
|
||||
except bumble.core.ProtocolError as error:
|
||||
lines.append(("ansired", f"!!! Protocol Error ({error})\n"))
|
||||
except bumble.core.TimeoutError:
|
||||
lines.append(("ansired", "!!! Timeout\n"))
|
||||
except Exception as error:
|
||||
lines.append(("ansired", f"!!! Error ({error})\n"))
|
||||
|
||||
self.remote_attributes_text.text = lines
|
||||
self.ui.invalidate()
|
||||
|
||||
async def do_get_phy(self, _):
|
||||
if not self.connected_peer:
|
||||
self.show_error('not connected')
|
||||
|
||||
18
bumble/at.py
18
bumble/at.py
@@ -14,13 +14,19 @@
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
class AtParsingError(core.InvalidPacketError):
|
||||
"""Error raised when parsing AT commands fails."""
|
||||
|
||||
|
||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
"""Split input parameters into tokens.
|
||||
Removes space characters outside of double quote blocks:
|
||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||
are ignored [..], unless they are embedded in numeric or string constants"
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = []
|
||||
in_quotes = False
|
||||
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
token = bytearray()
|
||||
elif char == b'(':
|
||||
if len(token) > 0:
|
||||
raise ValueError("open_paren following regular character")
|
||||
raise AtParsingError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
elif char == b'"':
|
||||
if len(token) > 0:
|
||||
raise ValueError("quote following regular character")
|
||||
raise AtParsingError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
else:
|
||||
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
|
||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
"""Parse the parameters using the comma and parenthesis separators.
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = tokenize_parameters(buffer)
|
||||
accumulator: List[list] = [[]]
|
||||
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
accumulator.append([])
|
||||
elif token == b')':
|
||||
if len(accumulator) < 2:
|
||||
raise ValueError("close_paren without matching open_paren")
|
||||
raise AtParsingError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
else:
|
||||
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
|
||||
accumulator[-1].append(current)
|
||||
if len(accumulator) > 1:
|
||||
raise ValueError("missing close_paren")
|
||||
raise AtParsingError("missing close_paren")
|
||||
return accumulator[0]
|
||||
|
||||
@@ -20,6 +20,7 @@ import enum
|
||||
import struct
|
||||
from typing import Dict, Type, Union, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
|
||||
@@ -88,7 +89,9 @@ class Frame:
|
||||
short_name = subclass.__name__.replace("ResponseFrame", "")
|
||||
category_class = ResponseFrame
|
||||
else:
|
||||
raise ValueError(f"invalid subclass name {subclass.__name__}")
|
||||
raise core.InvalidArgumentError(
|
||||
f"invalid subclass name {subclass.__name__}"
|
||||
)
|
||||
|
||||
uppercase_indexes = [
|
||||
i for i in range(len(short_name)) if short_name[i].isupper()
|
||||
@@ -106,7 +109,7 @@ class Frame:
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Frame:
|
||||
if data[0] >> 4 != 0:
|
||||
raise ValueError("first 4 bits must be 0s")
|
||||
raise core.InvalidPacketError("first 4 bits must be 0s")
|
||||
|
||||
ctype_or_response = data[0] & 0xF
|
||||
subunit_type = Frame.SubunitType(data[1] >> 3)
|
||||
@@ -122,7 +125,7 @@ class Frame:
|
||||
# Extended to the next byte
|
||||
extension = data[2]
|
||||
if extension == 0:
|
||||
raise ValueError("extended subunit ID value reserved")
|
||||
raise core.InvalidPacketError("extended subunit ID value reserved")
|
||||
if extension == 0xFF:
|
||||
subunit_id = 5 + 254 + data[3]
|
||||
opcode_offset = 4
|
||||
@@ -131,7 +134,7 @@ class Frame:
|
||||
opcode_offset = 3
|
||||
|
||||
elif subunit_id == 6:
|
||||
raise ValueError("reserved subunit ID")
|
||||
raise core.InvalidPacketError("reserved subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
@@ -448,7 +451,7 @@ class PassThroughFrame:
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
if len(operation_data) > 255:
|
||||
raise ValueError("operation data must be <= 255 bytes")
|
||||
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
|
||||
self.state_flag = state_flag
|
||||
self.operation_id = operation_id
|
||||
self.operation_data = operation_data
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import core
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -275,7 +276,7 @@ class Protocol:
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||
raise ValueError("command handler not registered")
|
||||
raise core.InvalidArgumentError("command handler not registered")
|
||||
del self.command_handlers[pid]
|
||||
|
||||
def register_response_handler(
|
||||
@@ -287,5 +288,5 @@ class Protocol:
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||
raise ValueError("response handler not registered")
|
||||
raise core.InvalidArgumentError("response handler not registered")
|
||||
del self.response_handlers[pid]
|
||||
|
||||
@@ -43,6 +43,7 @@ from .core import (
|
||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||
InvalidStateError,
|
||||
ProtocolError,
|
||||
InvalidArgumentError,
|
||||
name_or_number,
|
||||
)
|
||||
from .a2dp import (
|
||||
@@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
|
||||
signal_identifier_str = name[:-7]
|
||||
message_type = Message.MessageType.RESPONSE_REJECT
|
||||
else:
|
||||
raise ValueError('invalid class name')
|
||||
raise InvalidArgumentError('invalid class name')
|
||||
|
||||
subclass.message_type = message_type
|
||||
|
||||
@@ -2162,6 +2163,9 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
|
||||
def on_abort_command(self):
|
||||
self.emit('abort')
|
||||
|
||||
def on_delayreport_command(self, delay: int):
|
||||
self.emit('delay_report', delay)
|
||||
|
||||
def on_rtp_channel_open(self):
|
||||
self.emit('rtp_channel_open')
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ from bumble.sdp import (
|
||||
)
|
||||
from bumble.utils import AsyncRunner, OpenIntEnum
|
||||
from bumble.core import (
|
||||
InvalidArgumentError,
|
||||
ProtocolError,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_AVCTP_PROTOCOL_ID,
|
||||
@@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter):
|
||||
def notify_track_changed(self, identifier: bytes) -> None:
|
||||
"""Notify the connected peer of a Track change."""
|
||||
if len(identifier) != 8:
|
||||
raise ValueError("identifier must be 8 bytes")
|
||||
raise InvalidArgumentError("identifier must be 8 bytes")
|
||||
self.notify_event(TrackChangedEvent(identifier))
|
||||
|
||||
def notify_playback_position_changed(self, position: int) -> None:
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BitReader:
|
||||
@@ -40,7 +42,7 @@ class BitReader:
|
||||
""" "Read up to 32 bits."""
|
||||
|
||||
if bits > 32:
|
||||
raise ValueError('maximum read size is 32')
|
||||
raise core.InvalidArgumentError('maximum read size is 32')
|
||||
|
||||
if self.bits_cached >= bits:
|
||||
# We have enough bits.
|
||||
@@ -53,7 +55,7 @@ class BitReader:
|
||||
feed_size = len(feed_bytes)
|
||||
feed_int = int.from_bytes(feed_bytes, byteorder='big')
|
||||
if 8 * feed_size + self.bits_cached < bits:
|
||||
raise ValueError('trying to read past the data')
|
||||
raise core.InvalidArgumentError('trying to read past the data')
|
||||
self.byte_position += feed_size
|
||||
|
||||
# Combine the new cache and the old cache
|
||||
@@ -68,7 +70,7 @@ class BitReader:
|
||||
|
||||
def read_bytes(self, count: int):
|
||||
if self.bit_position + 8 * count > 8 * len(self.data):
|
||||
raise ValueError('not enough data')
|
||||
raise core.InvalidArgumentError('not enough data')
|
||||
|
||||
if self.bit_position % 8:
|
||||
# Not byte aligned
|
||||
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
|
||||
|
||||
@staticmethod
|
||||
def program_config_element(reader: BitReader):
|
||||
raise ValueError('program_config_element not supported')
|
||||
raise core.InvalidPacketError('program_config_element not supported')
|
||||
|
||||
@dataclass
|
||||
class GASpecificConfig:
|
||||
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
|
||||
aac_spectral_data_resilience_flags = reader.read(1)
|
||||
extension_flag_3 = reader.read(1)
|
||||
if extension_flag_3 == 1:
|
||||
raise ValueError('extensionFlag3 == 1 not supported')
|
||||
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
|
||||
|
||||
@staticmethod
|
||||
def audio_object_type(reader: BitReader):
|
||||
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
|
||||
reader, self.channel_configuration, self.audio_object_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
raise core.InvalidPacketError(
|
||||
f'audioObjectType {self.audio_object_type} not supported'
|
||||
)
|
||||
|
||||
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
|
||||
else:
|
||||
audio_mux_version_a = 0
|
||||
if audio_mux_version_a != 0:
|
||||
raise ValueError('audioMuxVersionA != 0 not supported')
|
||||
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
|
||||
if audio_mux_version == 1:
|
||||
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
||||
stream_cnt = 0
|
||||
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
|
||||
num_sub_frames = reader.read(6)
|
||||
num_program = reader.read(4)
|
||||
if num_program != 0:
|
||||
raise ValueError('num_program != 0 not supported')
|
||||
raise core.InvalidPacketError('num_program != 0 not supported')
|
||||
num_layer = reader.read(3)
|
||||
if num_layer != 0:
|
||||
raise ValueError('num_layer != 0 not supported')
|
||||
raise core.InvalidPacketError('num_layer != 0 not supported')
|
||||
if audio_mux_version == 0:
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
reader
|
||||
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
|
||||
)
|
||||
audio_specific_config_len = reader.bit_position - marker
|
||||
if asc_len < audio_specific_config_len:
|
||||
raise ValueError('audio_specific_config_len > asc_len')
|
||||
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
|
||||
asc_len -= audio_specific_config_len
|
||||
reader.skip(asc_len)
|
||||
frame_length_type = reader.read(3)
|
||||
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
|
||||
elif frame_length_type == 1:
|
||||
frame_length = reader.read(9)
|
||||
else:
|
||||
raise ValueError(f'frame_length_type {frame_length_type} not supported')
|
||||
raise core.InvalidPacketError(
|
||||
f'frame_length_type {frame_length_type} not supported'
|
||||
)
|
||||
|
||||
self.other_data_present = reader.read(1)
|
||||
if self.other_data_present:
|
||||
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
|
||||
|
||||
def __init__(self, reader: BitReader, mux_config_present: int):
|
||||
if mux_config_present == 0:
|
||||
raise ValueError('muxConfigPresent == 0 not supported')
|
||||
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
|
||||
|
||||
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
||||
use_same_stream_mux = reader.read(1)
|
||||
if use_same_stream_mux:
|
||||
raise ValueError('useSameStreamMux == 1 not supported')
|
||||
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
|
||||
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
||||
|
||||
# We only support:
|
||||
|
||||
@@ -16,6 +16,10 @@ from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class ColorError(ValueError):
|
||||
"""Error raised when a color spec is invalid."""
|
||||
|
||||
|
||||
# ANSI color names. There is also a "default"
|
||||
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
|
||||
|
||||
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
|
||||
elif isinstance(spec, int) and 0 <= spec <= 255:
|
||||
return _join(base + 8, 5, spec)
|
||||
else:
|
||||
raise ValueError('Invalid color spec "%s"' % spec)
|
||||
raise ColorError('Invalid color spec "%s"' % spec)
|
||||
|
||||
|
||||
def color(
|
||||
@@ -72,7 +76,7 @@ def color(
|
||||
if style_part in STYLES:
|
||||
codes.append(STYLES.index(style_part))
|
||||
else:
|
||||
raise ValueError('Invalid style "%s"' % style_part)
|
||||
raise ColorError('Invalid style "%s"' % style_part)
|
||||
|
||||
if codes:
|
||||
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
|
||||
|
||||
@@ -79,7 +79,13 @@ def get_dict_key_by_value(dictionary, value):
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# -----------------------------------------------------------------------------
|
||||
class BaseError(Exception):
|
||||
|
||||
|
||||
class BaseBumbleError(Exception):
|
||||
"""Base Error raised by Bumble."""
|
||||
|
||||
|
||||
class BaseError(BaseBumbleError):
|
||||
"""Base class for errors with an error code, error name and namespace"""
|
||||
|
||||
def __init__(
|
||||
@@ -118,18 +124,38 @@ class ProtocolError(BaseError):
|
||||
"""Protocol Error"""
|
||||
|
||||
|
||||
class TimeoutError(Exception): # pylint: disable=redefined-builtin
|
||||
class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin
|
||||
"""Timeout Error"""
|
||||
|
||||
|
||||
class CommandTimeoutError(Exception):
|
||||
class CommandTimeoutError(BaseBumbleError):
|
||||
"""Command Timeout Error"""
|
||||
|
||||
|
||||
class InvalidStateError(Exception):
|
||||
class InvalidStateError(BaseBumbleError):
|
||||
"""Invalid State Error"""
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseBumbleError, ValueError):
|
||||
"""Invalid Argument Error"""
|
||||
|
||||
|
||||
class InvalidPacketError(BaseBumbleError, ValueError):
|
||||
"""Invalid Packet Error"""
|
||||
|
||||
|
||||
class InvalidOperationError(BaseBumbleError, RuntimeError):
|
||||
"""Invalid Operation Error"""
|
||||
|
||||
|
||||
class OutOfResourcesError(BaseBumbleError, RuntimeError):
|
||||
"""Out of Resources Error"""
|
||||
|
||||
|
||||
class UnreachableError(BaseBumbleError):
|
||||
"""The code path raising this error should be unreachable."""
|
||||
|
||||
|
||||
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
|
||||
"""Connection Error"""
|
||||
|
||||
@@ -188,12 +214,12 @@ class UUID:
|
||||
or uuid_str_or_int[18] != '-'
|
||||
or uuid_str_or_int[23] != '-'
|
||||
):
|
||||
raise ValueError('invalid UUID format')
|
||||
raise InvalidArgumentError('invalid UUID format')
|
||||
uuid_str = uuid_str_or_int.replace('-', '')
|
||||
else:
|
||||
uuid_str = uuid_str_or_int
|
||||
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
|
||||
raise ValueError(f"invalid UUID format: {uuid_str}")
|
||||
raise InvalidArgumentError(f"invalid UUID format: {uuid_str}")
|
||||
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
||||
self.name = name
|
||||
|
||||
@@ -218,7 +244,7 @@ class UUID:
|
||||
|
||||
return self.register()
|
||||
|
||||
raise ValueError('only 2, 4 and 16 bytes are allowed')
|
||||
raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
|
||||
|
||||
@classmethod
|
||||
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
|
||||
|
||||
215
bumble/device.py
215
bumble/device.py
@@ -27,6 +27,7 @@ import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, IntEnum
|
||||
import functools
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
@@ -178,10 +179,16 @@ from .core import (
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
AdvertisingData,
|
||||
BaseBumbleError,
|
||||
ConnectionParameterUpdateError,
|
||||
CommandTimeoutError,
|
||||
ConnectionParameters,
|
||||
ConnectionPHY,
|
||||
InvalidArgumentError,
|
||||
InvalidOperationError,
|
||||
InvalidStateError,
|
||||
OutOfResourcesError,
|
||||
UnreachableError,
|
||||
)
|
||||
from .utils import (
|
||||
AsyncRunner,
|
||||
@@ -253,8 +260,9 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
|
||||
DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
|
||||
)
|
||||
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
|
||||
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
|
||||
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
|
||||
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -266,6 +274,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class ObjectLookupError(BaseBumbleError):
|
||||
"""Error raised when failed to lookup an object."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1133,6 +1143,15 @@ class Peer:
|
||||
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
|
||||
return await self.gatt_client.discover_attributes()
|
||||
|
||||
async def discover_all(self):
|
||||
await self.discover_services()
|
||||
for service in self.services:
|
||||
await self.discover_characteristics(service=service)
|
||||
|
||||
for service in self.services:
|
||||
for characteristic in service.characteristics:
|
||||
await self.discover_descriptors(characteristic=characteristic)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
characteristic: gatt_client.CharacteristicProxy,
|
||||
@@ -1172,8 +1191,20 @@ class Peer:
|
||||
return self.gatt_client.get_services_by_uuid(uuid)
|
||||
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
|
||||
self,
|
||||
uuid: core.UUID,
|
||||
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
|
||||
) -> List[gatt_client.CharacteristicProxy]:
|
||||
if isinstance(service, core.UUID):
|
||||
return list(
|
||||
itertools.chain(
|
||||
*[
|
||||
self.get_characteristics_by_uuid(uuid, s)
|
||||
for s in self.get_services_by_uuid(service)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
|
||||
|
||||
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
|
||||
@@ -1274,6 +1305,7 @@ class Connection(CompositeEventEmitter):
|
||||
handle: int
|
||||
transport: int
|
||||
self_address: Address
|
||||
self_resolvable_address: Optional[Address]
|
||||
peer_address: Address
|
||||
peer_resolvable_address: Optional[Address]
|
||||
peer_le_features: Optional[LeFeatureMask]
|
||||
@@ -1321,6 +1353,7 @@ class Connection(CompositeEventEmitter):
|
||||
handle,
|
||||
transport,
|
||||
self_address,
|
||||
self_resolvable_address,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
@@ -1332,6 +1365,7 @@ class Connection(CompositeEventEmitter):
|
||||
self.handle = handle
|
||||
self.transport = transport
|
||||
self.self_address = self_address
|
||||
self.self_resolvable_address = self_resolvable_address
|
||||
self.peer_address = peer_address
|
||||
self.peer_resolvable_address = peer_resolvable_address
|
||||
self.peer_name = None # Classic only
|
||||
@@ -1365,6 +1399,7 @@ class Connection(CompositeEventEmitter):
|
||||
None,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
device.public_address,
|
||||
None,
|
||||
peer_address,
|
||||
None,
|
||||
role,
|
||||
@@ -1523,7 +1558,9 @@ class Connection(CompositeEventEmitter):
|
||||
f'Connection(handle=0x{self.handle:04X}, '
|
||||
f'role={self.role_name}, '
|
||||
f'self_address={self.self_address}, '
|
||||
f'peer_address={self.peer_address})'
|
||||
f'self_resolvable_address={self.self_resolvable_address}, '
|
||||
f'peer_address={self.peer_address}, '
|
||||
f'peer_resolvable_address={self.peer_resolvable_address})'
|
||||
)
|
||||
|
||||
|
||||
@@ -1538,8 +1575,9 @@ class DeviceConfiguration:
|
||||
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
le_enabled: bool = True
|
||||
# LE host enable 2nd parameter
|
||||
le_simultaneous_enabled: bool = False
|
||||
le_privacy_enabled: bool = False
|
||||
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
|
||||
classic_enabled: bool = False
|
||||
classic_sc_enabled: bool = True
|
||||
classic_ssp_enabled: bool = True
|
||||
@@ -1555,6 +1593,7 @@ class DeviceConfiguration:
|
||||
irk: bytes = bytes(16) # This really must be changed for any level of security
|
||||
keystore: Optional[str] = None
|
||||
address_resolution_offload: bool = False
|
||||
address_generation_offload: bool = False
|
||||
cis_enabled: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -1640,7 +1679,9 @@ def with_connection_from_handle(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, connection_handle, *args, **kwargs):
|
||||
if (connection := self.lookup_connection(connection_handle)) is None:
|
||||
raise ValueError(f'no connection for handle: 0x{connection_handle:04x}')
|
||||
raise ObjectLookupError(
|
||||
f'no connection for handle: 0x{connection_handle:04x}'
|
||||
)
|
||||
return function(self, connection, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -1655,7 +1696,7 @@ def with_connection_from_address(function):
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, *args, **kwargs)
|
||||
raise ValueError('no connection for address')
|
||||
raise ObjectLookupError('no connection for address')
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1705,8 +1746,9 @@ device_host_event_handlers: List[str] = []
|
||||
# -----------------------------------------------------------------------------
|
||||
class Device(CompositeEventEmitter):
|
||||
# Incomplete list of fields.
|
||||
random_address: Address
|
||||
public_address: Address
|
||||
random_address: Address # Random address that may change with RPA
|
||||
public_address: Address # Public address (obtained from the controller)
|
||||
static_address: Address # Random address that can be set but does not change
|
||||
classic_enabled: bool
|
||||
name: str
|
||||
class_of_device: int
|
||||
@@ -1836,15 +1878,19 @@ class Device(CompositeEventEmitter):
|
||||
config = config or DeviceConfiguration()
|
||||
self.config = config
|
||||
|
||||
self.public_address = Address('00:00:00:00:00:00')
|
||||
self.name = config.name
|
||||
self.public_address = Address.ANY
|
||||
self.random_address = config.address
|
||||
self.static_address = config.address
|
||||
self.class_of_device = config.class_of_device
|
||||
self.keystore = None
|
||||
self.irk = config.irk
|
||||
self.le_enabled = config.le_enabled
|
||||
self.classic_enabled = config.classic_enabled
|
||||
self.le_simultaneous_enabled = config.le_simultaneous_enabled
|
||||
self.le_privacy_enabled = config.le_privacy_enabled
|
||||
self.le_rpa_timeout = config.le_rpa_timeout
|
||||
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
|
||||
self.classic_enabled = config.classic_enabled
|
||||
self.cis_enabled = config.cis_enabled
|
||||
self.classic_sc_enabled = config.classic_sc_enabled
|
||||
self.classic_ssp_enabled = config.classic_ssp_enabled
|
||||
@@ -1853,6 +1899,7 @@ class Device(CompositeEventEmitter):
|
||||
self.connectable = config.connectable
|
||||
self.classic_accept_any = config.classic_accept_any
|
||||
self.address_resolution_offload = config.address_resolution_offload
|
||||
self.address_generation_offload = config.address_generation_offload
|
||||
|
||||
# Extended advertising.
|
||||
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
|
||||
@@ -1908,6 +1955,7 @@ class Device(CompositeEventEmitter):
|
||||
if isinstance(address, str):
|
||||
address = Address(address)
|
||||
self.random_address = address
|
||||
self.static_address = address
|
||||
|
||||
# Setup SMP
|
||||
self.smp_manager = smp.Manager(
|
||||
@@ -2093,7 +2141,7 @@ class Device(CompositeEventEmitter):
|
||||
spec=spec,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unexpected mode {spec}')
|
||||
raise InvalidArgumentError(f'Unexpected mode {spec}')
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||
@@ -2139,6 +2187,16 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
|
||||
if self.le_enabled:
|
||||
# If LE Privacy is enabled, generate an RPA
|
||||
if self.le_privacy_enabled:
|
||||
self.random_address = Address.generate_private_address(self.irk)
|
||||
logger.info(f'Initial RPA: {self.random_address}')
|
||||
if self.le_rpa_timeout > 0:
|
||||
# Start a task to periodically generate a new RPA
|
||||
self.le_rpa_periodic_update_task = asyncio.create_task(
|
||||
self._run_rpa_periodic_update()
|
||||
)
|
||||
|
||||
# Set the controller address
|
||||
if self.random_address == Address.ANY_RANDOM:
|
||||
# Try to use an address generated at random by the controller
|
||||
@@ -2218,9 +2276,45 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def power_off(self) -> None:
|
||||
if self.powered_on:
|
||||
if self.le_rpa_periodic_update_task:
|
||||
self.le_rpa_periodic_update_task.cancel()
|
||||
|
||||
await self.host.flush()
|
||||
|
||||
self.powered_on = False
|
||||
|
||||
async def update_rpa(self) -> bool:
|
||||
"""
|
||||
Try to update the RPA.
|
||||
|
||||
Returns:
|
||||
True if the RPA was updated, False if it could not be updated.
|
||||
"""
|
||||
|
||||
# Check if this is a good time to rotate the address
|
||||
if self.is_advertising or self.is_scanning or self.is_le_connecting:
|
||||
logger.debug('skipping RPA update')
|
||||
return False
|
||||
|
||||
random_address = Address.generate_private_address(self.irk)
|
||||
response = await self.send_command(
|
||||
HCI_LE_Set_Random_Address_Command(random_address=self.random_address)
|
||||
)
|
||||
if response.return_parameters == HCI_SUCCESS:
|
||||
logger.info(f'new RPA: {random_address}')
|
||||
self.random_address = random_address
|
||||
return True
|
||||
else:
|
||||
logger.warning(f'failed to set RPA: {response.return_parameters}')
|
||||
return False
|
||||
|
||||
async def _run_rpa_periodic_update(self) -> None:
|
||||
"""Update the RPA periodically"""
|
||||
while self.le_rpa_timeout != 0:
|
||||
await asyncio.sleep(self.le_rpa_timeout)
|
||||
if not self.update_rpa():
|
||||
logger.debug("periodic RPA update failed")
|
||||
|
||||
async def refresh_resolving_list(self) -> None:
|
||||
assert self.keystore is not None
|
||||
|
||||
@@ -2228,7 +2322,7 @@ class Device(CompositeEventEmitter):
|
||||
# Create a host-side address resolver
|
||||
self.address_resolver = smp.AddressResolver(resolving_keys)
|
||||
|
||||
if self.address_resolution_offload:
|
||||
if self.address_resolution_offload or self.address_generation_offload:
|
||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
|
||||
|
||||
# Add an empty entry for non-directed address generation.
|
||||
@@ -2254,7 +2348,7 @@ class Device(CompositeEventEmitter):
|
||||
def supports_le_features(self, feature: LeFeatureMask) -> bool:
|
||||
return self.host.supports_le_features(feature)
|
||||
|
||||
def supports_le_phy(self, phy):
|
||||
def supports_le_phy(self, phy: int) -> bool:
|
||||
if phy == HCI_LE_1M_PHY:
|
||||
return True
|
||||
|
||||
@@ -2263,7 +2357,7 @@ class Device(CompositeEventEmitter):
|
||||
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
|
||||
}
|
||||
if phy not in feature_map:
|
||||
raise ValueError('invalid PHY')
|
||||
raise InvalidArgumentError('invalid PHY')
|
||||
|
||||
return self.supports_le_features(feature_map[phy])
|
||||
|
||||
@@ -2323,7 +2417,7 @@ class Device(CompositeEventEmitter):
|
||||
# Decide what peer address to use
|
||||
if advertising_type.is_directed:
|
||||
if target is None:
|
||||
raise ValueError('directed advertising requires a target')
|
||||
raise InvalidArgumentError('directed advertising requires a target')
|
||||
peer_address = target
|
||||
else:
|
||||
peer_address = Address.ANY
|
||||
@@ -2430,7 +2524,7 @@ class Device(CompositeEventEmitter):
|
||||
and advertising_data
|
||||
and scan_response_data
|
||||
):
|
||||
raise ValueError(
|
||||
raise InvalidArgumentError(
|
||||
"Extended advertisements can't have both data and scan \
|
||||
response data"
|
||||
)
|
||||
@@ -2446,7 +2540,9 @@ class Device(CompositeEventEmitter):
|
||||
if handle not in self.extended_advertising_sets
|
||||
)
|
||||
except StopIteration as exc:
|
||||
raise RuntimeError("all valid advertising handles already in use") from exc
|
||||
raise OutOfResourcesError(
|
||||
"all valid advertising handles already in use"
|
||||
) from exc
|
||||
|
||||
# Use the device's random address if a random address is needed but none was
|
||||
# provided.
|
||||
@@ -2545,14 +2641,14 @@ class Device(CompositeEventEmitter):
|
||||
) -> None:
|
||||
# Check that the arguments are legal
|
||||
if scan_interval < scan_window:
|
||||
raise ValueError('scan_interval must be >= scan_window')
|
||||
raise InvalidArgumentError('scan_interval must be >= scan_window')
|
||||
if (
|
||||
scan_interval < DEVICE_MIN_SCAN_INTERVAL
|
||||
or scan_interval > DEVICE_MAX_SCAN_INTERVAL
|
||||
):
|
||||
raise ValueError('scan_interval out of range')
|
||||
raise InvalidArgumentError('scan_interval out of range')
|
||||
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
|
||||
raise ValueError('scan_interval out of range')
|
||||
raise InvalidArgumentError('scan_interval out of range')
|
||||
|
||||
# Reset the accumulators
|
||||
self.advertisement_accumulators = {}
|
||||
@@ -2580,7 +2676,7 @@ class Device(CompositeEventEmitter):
|
||||
scanning_phy_count += 1
|
||||
|
||||
if scanning_phy_count == 0:
|
||||
raise ValueError('at least one scanning PHY must be enabled')
|
||||
raise InvalidArgumentError('at least one scanning PHY must be enabled')
|
||||
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Extended_Scan_Parameters_Command(
|
||||
@@ -2884,7 +2980,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# Check parameters
|
||||
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
|
||||
raise ValueError('invalid transport')
|
||||
raise InvalidArgumentError('invalid transport')
|
||||
|
||||
# Adjust the transport automatically if we need to
|
||||
if transport == BT_LE_TRANSPORT and not self.le_enabled:
|
||||
@@ -2901,7 +2997,7 @@ class Device(CompositeEventEmitter):
|
||||
peer_address = Address.from_string_for_transport(
|
||||
peer_address, transport
|
||||
)
|
||||
except ValueError:
|
||||
except InvalidArgumentError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(
|
||||
@@ -2913,7 +3009,7 @@ class Device(CompositeEventEmitter):
|
||||
transport == BT_BR_EDR_TRANSPORT
|
||||
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
|
||||
):
|
||||
raise ValueError('BR/EDR addresses must be PUBLIC')
|
||||
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
|
||||
|
||||
assert isinstance(peer_address, Address)
|
||||
|
||||
@@ -2964,7 +3060,7 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
)
|
||||
if not phys:
|
||||
raise ValueError('at least one supported PHY needed')
|
||||
raise InvalidArgumentError('at least one supported PHY needed')
|
||||
|
||||
phy_count = len(phys)
|
||||
initiating_phys = phy_list_to_bits(phys)
|
||||
@@ -3036,7 +3132,7 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
else:
|
||||
if HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||
raise ValueError('1M PHY preferences required')
|
||||
raise InvalidArgumentError('1M PHY preferences required')
|
||||
|
||||
prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
|
||||
result = await self.send_command(
|
||||
@@ -3136,7 +3232,7 @@ class Device(CompositeEventEmitter):
|
||||
if isinstance(peer_address, str):
|
||||
try:
|
||||
peer_address = Address(peer_address)
|
||||
except ValueError:
|
||||
except InvalidArgumentError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(
|
||||
@@ -3146,7 +3242,7 @@ class Device(CompositeEventEmitter):
|
||||
assert isinstance(peer_address, Address)
|
||||
|
||||
if peer_address == Address.NIL:
|
||||
raise ValueError('accept on nil address')
|
||||
raise InvalidArgumentError('accept on nil address')
|
||||
|
||||
# Create a future so that we can wait for the request
|
||||
pending_request_fut = asyncio.get_running_loop().create_future()
|
||||
@@ -3259,7 +3355,7 @@ class Device(CompositeEventEmitter):
|
||||
if isinstance(peer_address, str):
|
||||
try:
|
||||
peer_address = Address(peer_address)
|
||||
except ValueError:
|
||||
except InvalidArgumentError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(
|
||||
@@ -3302,10 +3398,10 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
|
||||
if tx_octets < 0x001B or tx_octets > 0x00FB:
|
||||
raise ValueError('tx_octets must be between 0x001B and 0x00FB')
|
||||
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
|
||||
|
||||
if tx_time < 0x0148 or tx_time > 0x4290:
|
||||
raise ValueError('tx_time must be between 0x0148 and 0x4290')
|
||||
raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')
|
||||
|
||||
return await self.send_command(
|
||||
HCI_LE_Set_Data_Length_Command(
|
||||
@@ -3580,7 +3676,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def encrypt(self, connection, enable=True):
|
||||
if not enable and connection.transport == BT_LE_TRANSPORT:
|
||||
raise ValueError('`enable` parameter is classic only.')
|
||||
raise InvalidArgumentError('`enable` parameter is classic only.')
|
||||
|
||||
# Set up event handlers
|
||||
pending_encryption = asyncio.get_running_loop().create_future()
|
||||
@@ -3599,11 +3695,11 @@ class Device(CompositeEventEmitter):
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
# Look for a key in the key store
|
||||
if self.keystore is None:
|
||||
raise RuntimeError('no key store')
|
||||
raise InvalidOperationError('no key store')
|
||||
|
||||
keys = await self.keystore.get(str(connection.peer_address))
|
||||
if keys is None:
|
||||
raise RuntimeError('keys not found in key store')
|
||||
raise InvalidOperationError('keys not found in key store')
|
||||
|
||||
if keys.ltk is not None:
|
||||
ltk = keys.ltk.value
|
||||
@@ -3614,7 +3710,7 @@ class Device(CompositeEventEmitter):
|
||||
rand = keys.ltk_central.rand
|
||||
ediv = keys.ltk_central.ediv
|
||||
else:
|
||||
raise RuntimeError('no LTK found for peer')
|
||||
raise InvalidOperationError('no LTK found for peer')
|
||||
|
||||
if connection.role != HCI_CENTRAL_ROLE:
|
||||
raise InvalidStateError('only centrals can start encryption')
|
||||
@@ -3889,7 +3985,7 @@ class Device(CompositeEventEmitter):
|
||||
return cis_link
|
||||
|
||||
# Mypy believes this is reachable when context is an ExitStack.
|
||||
raise InvalidStateError('Unreachable')
|
||||
raise UnreachableError()
|
||||
|
||||
# [LE only]
|
||||
@experimental('Only for testing.')
|
||||
@@ -4071,12 +4167,14 @@ class Device(CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
def on_connection(
|
||||
self,
|
||||
connection_handle,
|
||||
transport,
|
||||
peer_address,
|
||||
role,
|
||||
connection_parameters,
|
||||
):
|
||||
connection_handle: int,
|
||||
transport: int,
|
||||
peer_address: Address,
|
||||
self_resolvable_address: Optional[Address],
|
||||
peer_resolvable_address: Optional[Address],
|
||||
role: int,
|
||||
connection_parameters: ConnectionParameters,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'*** Connection: [0x{connection_handle:04X}] '
|
||||
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
|
||||
@@ -4097,15 +4195,15 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
return
|
||||
|
||||
# Resolve the peer address if we can
|
||||
peer_resolvable_address = None
|
||||
if self.address_resolver:
|
||||
if peer_address.is_resolvable:
|
||||
resolved_address = self.address_resolver.resolve(peer_address)
|
||||
if resolved_address is not None:
|
||||
logger.debug(f'*** Address resolved as {resolved_address}')
|
||||
peer_resolvable_address = peer_address
|
||||
peer_address = resolved_address
|
||||
if peer_resolvable_address is None:
|
||||
# Resolve the peer address if we can
|
||||
if self.address_resolver:
|
||||
if peer_address.is_resolvable:
|
||||
resolved_address = self.address_resolver.resolve(peer_address)
|
||||
if resolved_address is not None:
|
||||
logger.debug(f'*** Address resolved as {resolved_address}')
|
||||
peer_resolvable_address = peer_address
|
||||
peer_address = resolved_address
|
||||
|
||||
self_address = None
|
||||
if role == HCI_CENTRAL_ROLE:
|
||||
@@ -4136,12 +4234,19 @@ class Device(CompositeEventEmitter):
|
||||
else self.random_address
|
||||
)
|
||||
|
||||
# Convert all-zeros addresses into None.
|
||||
if self_resolvable_address == Address.ANY_RANDOM:
|
||||
self_resolvable_address = None
|
||||
if peer_resolvable_address == Address.ANY_RANDOM:
|
||||
peer_resolvable_address = None
|
||||
|
||||
# Create a connection.
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
transport,
|
||||
self_address,
|
||||
self_resolvable_address,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
@@ -4152,9 +4257,10 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
|
||||
if self.legacy_advertiser.auto_restart:
|
||||
advertiser = self.legacy_advertiser
|
||||
connection.once(
|
||||
'disconnection',
|
||||
lambda _: self.abort_on('flush', self.legacy_advertiser.start()),
|
||||
lambda _: self.abort_on('flush', advertiser.start()),
|
||||
)
|
||||
else:
|
||||
self.legacy_advertiser = None
|
||||
@@ -4377,7 +4483,7 @@ class Device(CompositeEventEmitter):
|
||||
return await pairing_config.delegate.confirm(auto=True)
|
||||
|
||||
async def na() -> bool:
|
||||
assert False, "N/A: unreachable"
|
||||
raise UnreachableError()
|
||||
|
||||
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
|
||||
methods = {
|
||||
@@ -4838,5 +4944,6 @@ class Device(CompositeEventEmitter):
|
||||
return (
|
||||
f'Device(name="{self.name}", '
|
||||
f'random_address="{self.random_address}", '
|
||||
f'public_address="{self.public_address}")'
|
||||
f'public_address="{self.public_address}", '
|
||||
f'static_address="{self.static_address}")'
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ from typing import Tuple
|
||||
import weakref
|
||||
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code,
|
||||
STATUS_SPEC,
|
||||
@@ -49,6 +50,10 @@ from bumble.drivers import common
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RtkFirmwareError(core.BaseBumbleError):
|
||||
"""Error raised when RTK firmware initialization fails."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -208,15 +213,15 @@ class Firmware:
|
||||
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
||||
|
||||
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
||||
raise ValueError("Firmware does not start with epatch signature")
|
||||
raise RtkFirmwareError("Firmware does not start with epatch signature")
|
||||
|
||||
if not firmware.endswith(extension_sig):
|
||||
raise ValueError("Firmware does not end with extension sig")
|
||||
raise RtkFirmwareError("Firmware does not end with extension sig")
|
||||
|
||||
# The firmware should start with a 14 byte header.
|
||||
epatch_header_size = 14
|
||||
if len(firmware) < epatch_header_size:
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Look for the "project ID", starting from the end.
|
||||
offset = len(firmware) - len(extension_sig)
|
||||
@@ -230,7 +235,7 @@ class Firmware:
|
||||
break
|
||||
|
||||
if length == 0:
|
||||
raise ValueError("Invalid 0-length instruction")
|
||||
raise RtkFirmwareError("Invalid 0-length instruction")
|
||||
|
||||
if opcode == 0 and length == 1:
|
||||
project_id = firmware[offset - 1]
|
||||
@@ -239,7 +244,7 @@ class Firmware:
|
||||
offset -= length
|
||||
|
||||
if project_id < 0:
|
||||
raise ValueError("Project ID not found")
|
||||
raise RtkFirmwareError("Project ID not found")
|
||||
|
||||
self.project_id = project_id
|
||||
|
||||
@@ -252,7 +257,7 @@ class Firmware:
|
||||
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
||||
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
||||
if epatch_header_size + 8 * num_patches > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
chip_id_table_offset = epatch_header_size
|
||||
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
||||
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
||||
@@ -266,7 +271,7 @@ class Firmware:
|
||||
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
||||
)
|
||||
if patch_offset + patch_length > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Get the SVN version for the patch
|
||||
(svn_version,) = struct.unpack_from(
|
||||
@@ -645,7 +650,7 @@ class Driver(common.Driver):
|
||||
):
|
||||
return await self.download_for_rtl8723b()
|
||||
|
||||
raise ValueError("ROM not supported")
|
||||
raise RtkFirmwareError("ROM not supported")
|
||||
|
||||
async def init_controller(self):
|
||||
await self.download_firmware()
|
||||
|
||||
@@ -331,9 +331,9 @@ class Client:
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
if mtu > 0xFFFF:
|
||||
raise ValueError('MTU must be <= 0xFFFF')
|
||||
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
|
||||
|
||||
# We can only send one request per connection
|
||||
if self.mtu_exchange_done:
|
||||
@@ -977,6 +977,7 @@ class Client:
|
||||
offset += len(part)
|
||||
|
||||
self.cache_value(attribute_handle, attribute_value)
|
||||
|
||||
# Return the value as bytes
|
||||
return attribute_value
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
AdvertisingData,
|
||||
DeviceClass,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
bit_flags_to_strings,
|
||||
name_or_number,
|
||||
@@ -92,14 +94,14 @@ def map_class_of_device(class_of_device):
|
||||
)
|
||||
|
||||
|
||||
def phy_list_to_bits(phys):
|
||||
def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int:
|
||||
if phys is None:
|
||||
return 0
|
||||
|
||||
phy_bits = 0
|
||||
for phy in phys:
|
||||
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
|
||||
raise ValueError('invalid PHY')
|
||||
raise InvalidArgumentError('invalid PHY')
|
||||
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
|
||||
return phy_bits
|
||||
|
||||
@@ -1553,7 +1555,7 @@ class HCI_Object:
|
||||
new_offset, field_value = field_type(data, offset)
|
||||
return (field_value, new_offset - offset)
|
||||
|
||||
raise ValueError(f'unknown field type {field_type}')
|
||||
raise InvalidArgumentError(f'unknown field type {field_type}')
|
||||
|
||||
@staticmethod
|
||||
def dict_from_bytes(data, offset, fields):
|
||||
@@ -1622,7 +1624,7 @@ class HCI_Object:
|
||||
if 0 <= field_value <= 255:
|
||||
field_bytes = bytes([field_value])
|
||||
else:
|
||||
raise ValueError('value too large for *-typed field')
|
||||
raise InvalidArgumentError('value too large for *-typed field')
|
||||
else:
|
||||
field_bytes = bytes(field_value)
|
||||
elif field_type == 'v':
|
||||
@@ -1641,7 +1643,9 @@ class HCI_Object:
|
||||
elif len(field_bytes) > field_type:
|
||||
field_bytes = field_bytes[:field_type]
|
||||
else:
|
||||
raise ValueError(f"don't know how to serialize type {type(field_value)}")
|
||||
raise InvalidArgumentError(
|
||||
f"don't know how to serialize type {type(field_value)}"
|
||||
)
|
||||
|
||||
return field_bytes
|
||||
|
||||
@@ -1835,6 +1839,12 @@ class Address:
|
||||
data, offset, Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_random_address(data, offset):
|
||||
return Address.parse_address_with_type(
|
||||
data, offset, Address.RANDOM_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_address_with_type(data, offset, address_type):
|
||||
return offset + 6, Address(data[offset : offset + 6], address_type)
|
||||
@@ -1905,7 +1915,7 @@ class Address:
|
||||
self.address_bytes = bytes(reversed(bytes.fromhex(address)))
|
||||
|
||||
if len(self.address_bytes) != 6:
|
||||
raise ValueError('invalid address length')
|
||||
raise InvalidArgumentError('invalid address length')
|
||||
|
||||
self.address_type = address_type
|
||||
|
||||
@@ -1961,7 +1971,8 @@ class Address:
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.address_bytes == other.address_bytes
|
||||
isinstance(other, Address)
|
||||
and self.address_bytes == other.address_bytes
|
||||
and self.is_public == other.is_public
|
||||
)
|
||||
|
||||
@@ -2108,7 +2119,7 @@ class HCI_Command(HCI_Packet):
|
||||
op_code, length = struct.unpack_from('<HB', packet, 1)
|
||||
parameters = packet[4:]
|
||||
if len(parameters) != length:
|
||||
raise ValueError('invalid packet length')
|
||||
raise InvalidPacketError('invalid packet length')
|
||||
|
||||
# Look for a registered class
|
||||
cls = HCI_Command.command_classes.get(op_code)
|
||||
@@ -4807,7 +4818,7 @@ class HCI_Event(HCI_Packet):
|
||||
length = packet[2]
|
||||
parameters = packet[3:]
|
||||
if len(parameters) != length:
|
||||
raise ValueError('invalid packet length')
|
||||
raise InvalidPacketError('invalid packet length')
|
||||
|
||||
cls: Any
|
||||
if event_code == HCI_LE_META_EVENT:
|
||||
@@ -5174,8 +5185,8 @@ class HCI_LE_Data_Length_Change_Event(HCI_LE_Meta_Event):
|
||||
),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('local_resolvable_private_address', Address.parse_address),
|
||||
('peer_resolvable_private_address', Address.parse_address),
|
||||
('local_resolvable_private_address', Address.parse_random_address),
|
||||
('peer_resolvable_private_address', Address.parse_random_address),
|
||||
('connection_interval', 2),
|
||||
('peripheral_latency', 2),
|
||||
('supervision_timeout', 2),
|
||||
@@ -6342,7 +6353,7 @@ class HCI_AclDataPacket(HCI_Packet):
|
||||
bc_flag = (h >> 14) & 3
|
||||
data = packet[5:]
|
||||
if len(data) != data_total_length:
|
||||
raise ValueError('invalid packet length')
|
||||
raise InvalidPacketError('invalid packet length')
|
||||
return HCI_AclDataPacket(
|
||||
connection_handle, pb_flag, bc_flag, data_total_length, data
|
||||
)
|
||||
@@ -6390,7 +6401,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
||||
packet_status = (h >> 12) & 0b11
|
||||
data = packet[4:]
|
||||
if len(data) != data_total_length:
|
||||
raise ValueError(
|
||||
raise InvalidPacketError(
|
||||
f'invalid packet length {len(data)} != {data_total_length}'
|
||||
)
|
||||
return HCI_SynchronousDataPacket(
|
||||
|
||||
@@ -772,6 +772,8 @@ class Host(AbortableEventEmitter):
|
||||
event.connection_handle,
|
||||
BT_LE_TRANSPORT,
|
||||
event.peer_address,
|
||||
getattr(event, 'local_resolvable_private_address', None),
|
||||
getattr(event, 'peer_resolvable_private_address', None),
|
||||
event.role,
|
||||
connection_parameters,
|
||||
)
|
||||
@@ -817,6 +819,8 @@ class Host(AbortableEventEmitter):
|
||||
event.bd_addr,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
|
||||
|
||||
@@ -41,7 +41,14 @@ from typing import (
|
||||
|
||||
from .utils import deprecated
|
||||
from .colors import color
|
||||
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
|
||||
from .core import (
|
||||
BT_CENTRAL_ROLE,
|
||||
InvalidStateError,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
OutOfResourcesError,
|
||||
ProtocolError,
|
||||
)
|
||||
from .hci import (
|
||||
HCI_LE_Connection_Update_Command,
|
||||
HCI_Object,
|
||||
@@ -189,17 +196,17 @@ class LeCreditBasedChannelSpec:
|
||||
self.max_credits < 1
|
||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
raise InvalidArgumentError('max credits out of range')
|
||||
if (
|
||||
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
||||
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
||||
):
|
||||
raise ValueError('MTU out of range')
|
||||
raise InvalidArgumentError('MTU out of range')
|
||||
if (
|
||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
):
|
||||
raise ValueError('MPS out of range')
|
||||
raise InvalidArgumentError('MPS out of range')
|
||||
|
||||
|
||||
class L2CAP_PDU:
|
||||
@@ -211,7 +218,7 @@ class L2CAP_PDU:
|
||||
def from_bytes(data: bytes) -> L2CAP_PDU:
|
||||
# Check parameters
|
||||
if len(data) < 4:
|
||||
raise ValueError('not enough data for L2CAP header')
|
||||
raise InvalidPacketError('not enough data for L2CAP header')
|
||||
|
||||
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
|
||||
l2cap_pdu_payload = data[4:]
|
||||
@@ -816,7 +823,7 @@ class ClassicChannel(EventEmitter):
|
||||
|
||||
# Check that we can start a new connection
|
||||
if self.connection_result:
|
||||
raise RuntimeError('connection already pending')
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||
self.send_control_frame(
|
||||
@@ -1129,7 +1136,7 @@ class LeCreditBasedChannel(EventEmitter):
|
||||
# Check that we can start a new connection
|
||||
identifier = self.manager.next_identifier(self.connection)
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise RuntimeError('too many concurrent connection requests')
|
||||
raise InvalidStateError('too many concurrent connection requests')
|
||||
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
@@ -1516,7 +1523,7 @@ class ChannelManager:
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
raise RuntimeError('no free CID available')
|
||||
raise OutOfResourcesError('no free CID available')
|
||||
|
||||
@staticmethod
|
||||
def find_free_le_cid(channels: Iterable[int]) -> int:
|
||||
@@ -1529,7 +1536,7 @@ class ChannelManager:
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
raise RuntimeError('no free CID')
|
||||
raise OutOfResourcesError('no free CID')
|
||||
|
||||
def next_identifier(self, connection: Connection) -> int:
|
||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||
@@ -1576,15 +1583,15 @@ class ChannelManager:
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if spec.psm in self.servers:
|
||||
raise ValueError('PSM already in use')
|
||||
raise InvalidArgumentError('PSM already in use')
|
||||
|
||||
# Check that the PSM is valid
|
||||
if spec.psm % 2 == 0:
|
||||
raise ValueError('invalid PSM (not odd)')
|
||||
raise InvalidArgumentError('invalid PSM (not odd)')
|
||||
check = spec.psm >> 8
|
||||
while check:
|
||||
if check % 2 != 0:
|
||||
raise ValueError('invalid PSM')
|
||||
raise InvalidArgumentError('invalid PSM')
|
||||
check >>= 8
|
||||
|
||||
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
|
||||
@@ -1626,7 +1633,7 @@ class ChannelManager:
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if spec.psm in self.le_coc_servers:
|
||||
raise ValueError('PSM already in use')
|
||||
raise InvalidArgumentError('PSM already in use')
|
||||
|
||||
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
|
||||
self,
|
||||
@@ -2154,10 +2161,10 @@ class ChannelManager:
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_le_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
raise OutOfResourcesError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise ValueError('PSM cannot be None')
|
||||
raise InvalidArgumentError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
|
||||
@@ -2206,10 +2213,10 @@ class ChannelManager:
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_br_edr_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
raise OutOfResourcesError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise ValueError('PSM cannot be None')
|
||||
raise InvalidArgumentError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(
|
||||
|
||||
@@ -19,7 +19,12 @@ import logging
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
|
||||
from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
InvalidStateError,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
@@ -405,12 +410,12 @@ class RemoteLink:
|
||||
|
||||
def add_controller(self, controller):
|
||||
if self.controller:
|
||||
raise ValueError('controller already set')
|
||||
raise InvalidStateError('controller already set')
|
||||
self.controller = controller
|
||||
|
||||
def remove_controller(self, controller):
|
||||
if self.controller != controller:
|
||||
raise ValueError('controller mismatch')
|
||||
raise InvalidStateError('controller mismatch')
|
||||
self.controller = None
|
||||
|
||||
def get_pending_connection(self):
|
||||
|
||||
@@ -685,10 +685,11 @@ class CodecSpecificConfiguration:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PacRecord:
|
||||
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
|
||||
|
||||
coding_format: hci.CodingFormat
|
||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||
# TODO: Parse Metadata
|
||||
metadata: bytes = b''
|
||||
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> PacRecord:
|
||||
@@ -701,7 +702,8 @@ class PacRecord:
|
||||
]
|
||||
offset += codec_specific_capabilities_size
|
||||
metadata_size = data[offset]
|
||||
metadata = data[offset : offset + metadata_size]
|
||||
offset += 1
|
||||
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
|
||||
|
||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||
@@ -719,12 +721,13 @@ class PacRecord:
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
return (
|
||||
bytes(self.coding_format)
|
||||
+ bytes([len(capabilities_bytes)])
|
||||
+ capabilities_bytes
|
||||
+ bytes([len(self.metadata)])
|
||||
+ self.metadata
|
||||
+ bytes([len(metadata_bytes)])
|
||||
+ metadata_bytes
|
||||
)
|
||||
|
||||
|
||||
@@ -940,8 +943,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
presentation_delay = 0
|
||||
|
||||
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
||||
# TODO: Parse this
|
||||
metadata = b''
|
||||
metadata = le_audio.Metadata()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1088,7 +1090,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.metadata = metadata
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
self.state = self.State.ENABLING
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
@@ -1140,7 +1142,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.metadata = metadata
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
@@ -1217,8 +1219,9 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.State.STREAMING,
|
||||
self.State.DISABLING,
|
||||
):
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
additional_parameters = (
|
||||
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
|
||||
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
|
||||
)
|
||||
else:
|
||||
additional_parameters = b''
|
||||
|
||||
@@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||
raise ValueError(
|
||||
raise core.InvalidArgumentError(
|
||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||
)
|
||||
|
||||
@@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
key = await connection.device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||
|
||||
@@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
raise RuntimeError('Invalid SIRK value')
|
||||
raise core.InvalidPacketError('Invalid SIRK value')
|
||||
|
||||
sirk_type = SirkType(response[0])
|
||||
if sirk_type == SirkType.PLAINTEXT:
|
||||
@@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
key = await device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk = sef(key, response[1:])
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
from enum import IntEnum
|
||||
import struct
|
||||
|
||||
from bumble import core
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..att import ATT_Error
|
||||
from ..gatt import (
|
||||
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
|
||||
rr_intervals=None,
|
||||
):
|
||||
if heart_rate < 0 or heart_rate > 0xFFFF:
|
||||
raise ValueError('heart_rate out of range')
|
||||
raise core.InvalidArgumentError('heart_rate out of range')
|
||||
|
||||
if energy_expended is not None and (
|
||||
energy_expended < 0 or energy_expended > 0xFFFF
|
||||
):
|
||||
raise ValueError('energy_expended out of range')
|
||||
raise core.InvalidArgumentError('energy_expended out of range')
|
||||
|
||||
if rr_intervals:
|
||||
for rr_interval in rr_intervals:
|
||||
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
|
||||
raise ValueError('rr_intervals out of range')
|
||||
raise core.InvalidArgumentError('rr_intervals out of range')
|
||||
|
||||
self.heart_rate = heart_rate
|
||||
self.sensor_contact_detected = sensor_contact_detected
|
||||
|
||||
@@ -17,33 +17,67 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
from typing import List
|
||||
import struct
|
||||
from typing import List, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Metadata:
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
|
||||
|
||||
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
|
||||
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
|
||||
again outside the lib.
|
||||
'''
|
||||
|
||||
class Tag(utils.OpenIntEnum):
|
||||
# fmt: off
|
||||
PREFERRED_AUDIO_CONTEXTS = 0x01
|
||||
STREAMING_AUDIO_CONTEXTS = 0x02
|
||||
PROGRAM_INFO = 0x03
|
||||
LANGUAGE = 0x04
|
||||
CCID_LIST = 0x05
|
||||
PARENTAL_RATING = 0x06
|
||||
PROGRAM_INFO_URI = 0x07
|
||||
AUDIO_ACTIVE_STATE = 0x08
|
||||
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
|
||||
ASSISTED_LISTENING_STREAM = 0x0A
|
||||
BROADCAST_NAME = 0x0B
|
||||
EXTENDED_METADATA = 0xFE
|
||||
VENDOR_SPECIFIC = 0xFF
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Entry:
|
||||
tag: int
|
||||
tag: Metadata.Tag
|
||||
data: bytes
|
||||
|
||||
entries: List[Entry]
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([len(self.data) + 1, self.tag]) + self.data
|
||||
|
||||
entries: List[Entry] = dataclasses.field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
entries = []
|
||||
offset = 0
|
||||
length = len(data)
|
||||
while length >= 2:
|
||||
while offset < length:
|
||||
entry_length = data[offset]
|
||||
entry_tag = data[offset + 1]
|
||||
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
|
||||
entries.append(cls.Entry(entry_tag, entry_data))
|
||||
length -= entry_length
|
||||
offset += 1
|
||||
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
|
||||
offset += entry_length
|
||||
|
||||
return cls(entries)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return b''.join([bytes(entry) for entry in self.entries])
|
||||
|
||||
448
bumble/profiles/mcp.py
Normal file
448
bumble/profiles/mcp.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from bumble import core
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import utils
|
||||
|
||||
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PlayingOrder(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.15. Playing Order.'''
|
||||
|
||||
SINGLE_ONCE = 0x01
|
||||
SINGLE_REPEAT = 0x02
|
||||
IN_ORDER_ONCE = 0x03
|
||||
IN_ORDER_REPEAT = 0x04
|
||||
OLDEST_ONCE = 0x05
|
||||
OLDEST_REPEAT = 0x06
|
||||
NEWEST_ONCE = 0x07
|
||||
NEWEST_REPEAT = 0x08
|
||||
SHUFFLE_ONCE = 0x09
|
||||
SHUFFLE_REPEAT = 0x0A
|
||||
|
||||
|
||||
class PlayingOrderSupported(enum.IntFlag):
|
||||
'''See Media Control Service 3.16. Playing Orders Supported.'''
|
||||
|
||||
SINGLE_ONCE = 0x0001
|
||||
SINGLE_REPEAT = 0x0002
|
||||
IN_ORDER_ONCE = 0x0004
|
||||
IN_ORDER_REPEAT = 0x0008
|
||||
OLDEST_ONCE = 0x0010
|
||||
OLDEST_REPEAT = 0x0020
|
||||
NEWEST_ONCE = 0x0040
|
||||
NEWEST_REPEAT = 0x0080
|
||||
SHUFFLE_ONCE = 0x0100
|
||||
SHUFFLE_REPEAT = 0x0200
|
||||
|
||||
|
||||
class MediaState(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.17. Media State.'''
|
||||
|
||||
INACTIVE = 0x00
|
||||
PLAYING = 0x01
|
||||
PAUSED = 0x02
|
||||
SEEKING = 0x03
|
||||
|
||||
|
||||
class MediaControlPointOpcode(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.18. Media Control Point.'''
|
||||
|
||||
PLAY = 0x01
|
||||
PAUSE = 0x02
|
||||
FAST_REWIND = 0x03
|
||||
FAST_FORWARD = 0x04
|
||||
STOP = 0x05
|
||||
MOVE_RELATIVE = 0x10
|
||||
PREVIOUS_SEGMENT = 0x20
|
||||
NEXT_SEGMENT = 0x21
|
||||
FIRST_SEGMENT = 0x22
|
||||
LAST_SEGMENT = 0x23
|
||||
GOTO_SEGMENT = 0x24
|
||||
PREVIOUS_TRACK = 0x30
|
||||
NEXT_TRACK = 0x31
|
||||
FIRST_TRACK = 0x32
|
||||
LAST_TRACK = 0x33
|
||||
GOTO_TRACK = 0x34
|
||||
PREVIOUS_GROUP = 0x40
|
||||
NEXT_GROUP = 0x41
|
||||
FIRST_GROUP = 0x42
|
||||
LAST_GROUP = 0x43
|
||||
GOTO_GROUP = 0x44
|
||||
|
||||
|
||||
class MediaControlPointResultCode(enum.IntFlag):
|
||||
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
|
||||
|
||||
SUCCESS = 0x01
|
||||
OPCODE_NOT_SUPPORTED = 0x02
|
||||
MEDIA_PLAYER_INACTIVE = 0x03
|
||||
COMMAND_CANNOT_BE_COMPLETED = 0x04
|
||||
|
||||
|
||||
class MediaControlPointOpcodeSupported(enum.IntFlag):
|
||||
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
|
||||
|
||||
PLAY = 0x00000001
|
||||
PAUSE = 0x00000002
|
||||
FAST_REWIND = 0x00000004
|
||||
FAST_FORWARD = 0x00000008
|
||||
STOP = 0x00000010
|
||||
MOVE_RELATIVE = 0x00000020
|
||||
PREVIOUS_SEGMENT = 0x00000040
|
||||
NEXT_SEGMENT = 0x00000080
|
||||
FIRST_SEGMENT = 0x00000100
|
||||
LAST_SEGMENT = 0x00000200
|
||||
GOTO_SEGMENT = 0x00000400
|
||||
PREVIOUS_TRACK = 0x00000800
|
||||
NEXT_TRACK = 0x00001000
|
||||
FIRST_TRACK = 0x00002000
|
||||
LAST_TRACK = 0x00004000
|
||||
GOTO_TRACK = 0x00008000
|
||||
PREVIOUS_GROUP = 0x00010000
|
||||
NEXT_GROUP = 0x00020000
|
||||
FIRST_GROUP = 0x00040000
|
||||
LAST_GROUP = 0x00080000
|
||||
GOTO_GROUP = 0x00100000
|
||||
|
||||
|
||||
class SearchControlPointItemType(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.20. Search Control Point.'''
|
||||
|
||||
TRACK_NAME = 0x01
|
||||
ARTIST_NAME = 0x02
|
||||
ALBUM_NAME = 0x03
|
||||
GROUP_NAME = 0x04
|
||||
EARLIEST_YEAR = 0x05
|
||||
LATEST_YEAR = 0x06
|
||||
GENRE = 0x07
|
||||
ONLY_TRACKS = 0x08
|
||||
ONLY_GROUPS = 0x09
|
||||
|
||||
|
||||
class ObjectType(utils.OpenIntEnum):
|
||||
'''See Media Control Service 4.4.1. Object Type field.'''
|
||||
|
||||
TASK = 0
|
||||
GROUP = 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ObjectId(int):
|
||||
'''See Media Control Service 4.4.2. Object ID field.'''
|
||||
|
||||
@classmethod
|
||||
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data, byteorder='little', signed=False))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes(6, 'little')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GroupObjectType:
|
||||
'''See Media Control Service 4.4. Group Object Type.'''
|
||||
|
||||
object_type: ObjectType
|
||||
object_id: ObjectId
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(
|
||||
object_type=ObjectType(data[0]),
|
||||
object_id=ObjectId.create_from_bytes(data[1:]),
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.object_type]) + bytes(self.object_id)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaControlService(gatt.TemplateService):
|
||||
'''Media Control Service server implementation, only for testing currently.'''
|
||||
|
||||
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
|
||||
|
||||
def __init__(self, media_player_name: Optional[str] = None) -> None:
|
||||
self.track_position = 0
|
||||
|
||||
self.media_player_name_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=media_player_name or 'Bumble Player',
|
||||
)
|
||||
self.track_changed_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_title_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_duration_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_position_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.media_state_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.media_control_point_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self.on_media_control_point),
|
||||
)
|
||||
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.content_control_id_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.media_player_name_characteristic,
|
||||
self.track_changed_characteristic,
|
||||
self.track_title_characteristic,
|
||||
self.track_duration_characteristic,
|
||||
self.track_position_characteristic,
|
||||
self.media_state_characteristic,
|
||||
self.media_control_point_characteristic,
|
||||
self.media_control_point_opcodes_supported_characteristic,
|
||||
self.content_control_id_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
async def on_media_control_point(
|
||||
self, connection: Optional[device.Connection], data: bytes
|
||||
) -> None:
|
||||
if not connection:
|
||||
raise core.InvalidStateError()
|
||||
|
||||
opcode = MediaControlPointOpcode(data[0])
|
||||
|
||||
await connection.device.notify_subscriber(
|
||||
connection,
|
||||
self.media_control_point_characteristic,
|
||||
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
|
||||
)
|
||||
|
||||
|
||||
class GenericMediaControlService(MediaControlService):
|
||||
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaControlServiceProxy(
|
||||
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
|
||||
):
|
||||
SERVICE_CLASS = MediaControlService
|
||||
|
||||
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
|
||||
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
|
||||
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
|
||||
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
|
||||
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
|
||||
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
|
||||
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
|
||||
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
|
||||
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
|
||||
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
|
||||
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||
}
|
||||
|
||||
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_changed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_title: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_duration: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_position: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_order: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_state: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
|
||||
None
|
||||
)
|
||||
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
media_control_point_notifications: asyncio.Queue[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
utils.CompositeEventEmitter.__init__(self)
|
||||
self.service_proxy = service_proxy
|
||||
self.lock = asyncio.Lock()
|
||||
self.media_control_point_notifications = asyncio.Queue()
|
||||
|
||||
for field, uuid in self._CHARACTERISTICS.items():
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
|
||||
setattr(self, field, characteristics[0])
|
||||
|
||||
async def subscribe_characteristics(self) -> None:
|
||||
if self.media_control_point:
|
||||
await self.media_control_point.subscribe(self._on_media_control_point)
|
||||
if self.media_state:
|
||||
await self.media_state.subscribe(self._on_media_state)
|
||||
if self.track_changed:
|
||||
await self.track_changed.subscribe(self._on_track_changed)
|
||||
if self.track_title:
|
||||
await self.track_title.subscribe(self._on_track_title)
|
||||
if self.track_duration:
|
||||
await self.track_duration.subscribe(self._on_track_duration)
|
||||
if self.track_position:
|
||||
await self.track_position.subscribe(self._on_track_position)
|
||||
|
||||
async def write_control_point(
|
||||
self, opcode: MediaControlPointOpcode
|
||||
) -> MediaControlPointResultCode:
|
||||
'''Writes a Media Control Point Opcode to peer and waits for the notification.
|
||||
|
||||
The write operation will be executed when there isn't other pending commands.
|
||||
|
||||
Args:
|
||||
opcode: opcode defined in `MediaControlPointOpcode`.
|
||||
|
||||
Returns:
|
||||
Response code provided in `MediaControlPointResultCode`
|
||||
|
||||
Raises:
|
||||
InvalidOperationError: Server does not have Media Control Point Characteristic.
|
||||
InvalidStateError: Server replies a notification with mismatched opcode.
|
||||
'''
|
||||
if not self.media_control_point:
|
||||
raise core.InvalidOperationError("Peer does not have media control point")
|
||||
|
||||
async with self.lock:
|
||||
await self.media_control_point.write_value(
|
||||
bytes([opcode]),
|
||||
with_response=False,
|
||||
)
|
||||
|
||||
(
|
||||
response_opcode,
|
||||
response_code,
|
||||
) = await self.media_control_point_notifications.get()
|
||||
if response_opcode != opcode:
|
||||
raise core.InvalidStateError(
|
||||
f"Expected {opcode} notification, but get {response_opcode}"
|
||||
)
|
||||
return MediaControlPointResultCode(response_code)
|
||||
|
||||
def _on_media_control_point(self, data: bytes) -> None:
|
||||
self.media_control_point_notifications.put_nowait(data)
|
||||
|
||||
def _on_media_state(self, data: bytes) -> None:
|
||||
self.emit('media_state', MediaState(data[0]))
|
||||
|
||||
def _on_track_changed(self, data: bytes) -> None:
|
||||
del data
|
||||
self.emit('track_changed')
|
||||
|
||||
def _on_track_title(self, data: bytes) -> None:
|
||||
self.emit('track_title', data.decode("utf-8"))
|
||||
|
||||
def _on_track_duration(self, data: bytes) -> None:
|
||||
self.emit('track_duration', struct.unpack_from('<i', data)[0])
|
||||
|
||||
def _on_track_position(self, data: bytes) -> None:
|
||||
self.emit('track_position', struct.unpack_from('<i', data)[0])
|
||||
|
||||
|
||||
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
|
||||
SERVICE_CLASS = GenericMediaControlService
|
||||
@@ -36,7 +36,9 @@ from .core import (
|
||||
BT_RFCOMM_PROTOCOL_ID,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
InvalidArgumentError,
|
||||
InvalidStateError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
)
|
||||
|
||||
@@ -335,7 +337,7 @@ class RFCOMM_Frame:
|
||||
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
|
||||
if frame.fcs != fcs:
|
||||
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
|
||||
raise ValueError('fcs mismatch')
|
||||
raise InvalidPacketError('fcs mismatch')
|
||||
|
||||
return frame
|
||||
|
||||
@@ -713,7 +715,7 @@ class DLC(EventEmitter):
|
||||
# Automatically convert strings to bytes using UTF-8
|
||||
data = data.encode('utf-8')
|
||||
else:
|
||||
raise ValueError('write only accept bytes or strings')
|
||||
raise InvalidArgumentError('write only accept bytes or strings')
|
||||
|
||||
self.tx_buffer += data
|
||||
self.drained.clear()
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing_extensions import Self
|
||||
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import InvalidStateError
|
||||
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -189,7 +189,9 @@ class DataElement:
|
||||
self.bytes = None
|
||||
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||
if value_size is None:
|
||||
raise ValueError('integer types must have a value size specified')
|
||||
raise InvalidArgumentError(
|
||||
'integer types must have a value size specified'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def nil() -> DataElement:
|
||||
@@ -265,7 +267,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>Q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_from_bytes(data):
|
||||
@@ -281,7 +283,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def list_from_bytes(data):
|
||||
@@ -354,7 +356,7 @@ class DataElement:
|
||||
data = b''
|
||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise ValueError('UNSIGNED_INTEGER cannot be negative')
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('B', self.value)
|
||||
@@ -365,7 +367,7 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.SIGNED_INTEGER:
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('b', self.value)
|
||||
@@ -376,7 +378,7 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type == DataElement.URL:
|
||||
@@ -392,7 +394,7 @@ class DataElement:
|
||||
size_bytes = b''
|
||||
if self.type == DataElement.NIL:
|
||||
if size != 0:
|
||||
raise ValueError('NIL must be empty')
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif self.type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
@@ -410,7 +412,7 @@ class DataElement:
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type in (
|
||||
DataElement.TEXT_STRING,
|
||||
DataElement.SEQUENCE,
|
||||
@@ -427,10 +429,10 @@ class DataElement:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise ValueError('boolean must be 1 byte')
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
|
||||
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
|
||||
@@ -55,6 +55,7 @@ from .core import (
|
||||
BT_CENTRAL_ROLE,
|
||||
BT_LE_TRANSPORT,
|
||||
AdvertisingData,
|
||||
InvalidArgumentError,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
@@ -766,8 +767,11 @@ class Session:
|
||||
self.oob_data_flag = 0 if pairing_config.oob is None else 1
|
||||
|
||||
# Set up addresses
|
||||
self_address = connection.self_address
|
||||
self_address = connection.self_resolvable_address or connection.self_address
|
||||
peer_address = connection.peer_resolvable_address or connection.peer_address
|
||||
logger.debug(
|
||||
f"pairing with self_address={self_address}, peer_address={peer_address}"
|
||||
)
|
||||
if self.is_initiator:
|
||||
self.ia = bytes(self_address)
|
||||
self.iat = 1 if self_address.is_random else 0
|
||||
@@ -784,7 +788,7 @@ class Session:
|
||||
self.peer_oob_data = pairing_config.oob.peer_data
|
||||
if pairing_config.sc:
|
||||
if pairing_config.oob.our_context is None:
|
||||
raise ValueError(
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a context when sc is True"
|
||||
)
|
||||
self.r = pairing_config.oob.our_context.r
|
||||
@@ -793,7 +797,7 @@ class Session:
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
if pairing_config.oob.legacy_context is None:
|
||||
raise ValueError(
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a legacy context when sc is False"
|
||||
)
|
||||
self.r = bytes(16)
|
||||
@@ -1075,9 +1079,9 @@ class Session:
|
||||
|
||||
def send_identity_address_command(self) -> None:
|
||||
identity_address = {
|
||||
None: self.connection.self_address,
|
||||
None: self.manager.device.static_address,
|
||||
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
|
||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
|
||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
|
||||
}[self.pairing_config.identity_address_type]
|
||||
self.send_command(
|
||||
SMP_Identity_Address_Information_Command(
|
||||
|
||||
@@ -23,6 +23,7 @@ import datetime
|
||||
from typing import BinaryIO, Generator
|
||||
import os
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||
|
||||
|
||||
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
|
||||
"""
|
||||
if ':' not in spec:
|
||||
raise ValueError('snooper type prefix missing')
|
||||
raise core.InvalidArgumentError('snooper type prefix missing')
|
||||
|
||||
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
||||
|
||||
if snooper_type == 'btsnoop':
|
||||
if ':' not in snooper_args:
|
||||
raise ValueError('I/O type for btsnoop snooper type missing')
|
||||
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
|
||||
|
||||
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
||||
if io_type == 'file':
|
||||
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
_SNOOPER_INSTANCE_COUNT -= 1
|
||||
return
|
||||
|
||||
raise ValueError(f'I/O type {io_type} not supported')
|
||||
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
|
||||
|
||||
raise ValueError(f'snooper type {snooper_type} not found')
|
||||
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
|
||||
from ..snoop import create_snooper
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -180,7 +180,13 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
if scheme == 'unix':
|
||||
from .unix import open_unix_client_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_client_transport(spec)
|
||||
|
||||
raise TransportSpecError('unknown transport scheme')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -20,7 +20,13 @@ import grpc.aio
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
||||
from .common import (
|
||||
PumpedTransport,
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
||||
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
elif ':' in param:
|
||||
server_host, server_port = param.split(':')
|
||||
else:
|
||||
raise ValueError('invalid parameter')
|
||||
raise TransportSpecError('invalid parameter')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
service = VhciForwardingServiceStub(channel)
|
||||
hci_device = HciDevice(service.attachVhci())
|
||||
else:
|
||||
raise ValueError('invalid mode')
|
||||
raise TransportSpecError('invalid mode')
|
||||
|
||||
# Create the transport object
|
||||
class EmulatorTransport(PumpedTransport):
|
||||
|
||||
@@ -31,6 +31,8 @@ from .common import (
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
TransportInitError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise ValueError('invalid port')
|
||||
raise TransportSpecError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
|
||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||
server_port = find_grpc_port(instance_number)
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
raise TransportInitError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
|
||||
|
||||
if response_type == 'error':
|
||||
logger.warning(f'received error: {response.error}')
|
||||
raise RuntimeError(response.error)
|
||||
raise TransportInitError(response.error)
|
||||
|
||||
if response_type == 'hci_packet':
|
||||
return (
|
||||
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
|
||||
+ response.hci_packet.packet
|
||||
)
|
||||
|
||||
raise ValueError('unsupported response type')
|
||||
raise TransportSpecError('unsupported response type')
|
||||
|
||||
async def write(self, packet):
|
||||
await self.hci_device.write(
|
||||
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
options: Dict[str, str] = {}
|
||||
for param in params[params_offset:]:
|
||||
if '=' not in param:
|
||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
||||
raise TransportSpecError('invalid parameter, expected <name>=<value>')
|
||||
option_name, option_value = param.split('=')
|
||||
options[option_name] = option_value
|
||||
|
||||
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
)
|
||||
if mode == 'controller':
|
||||
if host is None:
|
||||
raise ValueError('<host>:<port> missing')
|
||||
raise TransportSpecError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port, options)
|
||||
|
||||
raise ValueError('invalid mode option')
|
||||
raise TransportSpecError('invalid mode option')
|
||||
|
||||
@@ -23,6 +23,7 @@ import logging
|
||||
import io
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.snoop import Snooper
|
||||
@@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
||||
# -----------------------------------------------------------------------------
|
||||
# Errors
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportLostError(Exception):
|
||||
"""
|
||||
The Transport has been lost/disconnected.
|
||||
"""
|
||||
class TransportLostError(core.BaseBumbleError, RuntimeError):
|
||||
"""The Transport has been lost/disconnected."""
|
||||
|
||||
|
||||
class TransportInitError(core.BaseBumbleError, RuntimeError):
|
||||
"""Error raised when the transport cannot be initialized."""
|
||||
|
||||
|
||||
class TransportSpecError(core.BaseBumbleError, ValueError):
|
||||
"""Error raised when the transport spec is invalid."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -132,7 +139,9 @@ class PacketParser:
|
||||
packet_type
|
||||
) or self.extended_packet_info.get(packet_type)
|
||||
if self.packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type}')
|
||||
raise core.InvalidPacketError(
|
||||
f'invalid packet type {packet_type}'
|
||||
)
|
||||
self.state = PacketParser.NEED_LENGTH
|
||||
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
||||
elif self.state == PacketParser.NEED_LENGTH:
|
||||
@@ -178,19 +187,19 @@ class PacketReader:
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
header = self.source.read(header_size)
|
||||
if len(header) != header_size:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
# Read the body
|
||||
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
|
||||
body = self.source.read(body_length)
|
||||
if len(body) != body_length:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
return packet_type + header + body
|
||||
|
||||
@@ -211,7 +220,7 @@ class AsyncPacketReader:
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
@@ -420,7 +429,7 @@ class SnoopingTransport(Transport):
|
||||
return SnoopingTransport(
|
||||
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
||||
)
|
||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
||||
raise core.UnreachableError() # Satisfy the type checker
|
||||
|
||||
class Source:
|
||||
sink: TransportSink
|
||||
|
||||
@@ -29,7 +29,7 @@ from usb.core import USBError
|
||||
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
|
||||
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
|
||||
|
||||
from .common import Transport, ParserSource
|
||||
from .common import Transport, ParserSource, TransportInitError
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
|
||||
@@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
device = None
|
||||
|
||||
if device is None:
|
||||
raise ValueError('device not found')
|
||||
raise TransportInitError('device not found')
|
||||
logger.debug(f'USB Device: {device}')
|
||||
|
||||
# Power Cycle the device
|
||||
|
||||
56
bumble/transport/unix.py
Normal file
56
bumble/transport/unix.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from .common import Transport, StreamPacketSource, StreamPacketSink
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_unix_client_transport(spec: str) -> Transport:
|
||||
'''Open a UNIX socket client transport.
|
||||
|
||||
The parameter is the path of unix socket. For abstract socket, the first character
|
||||
needs to be '@'.
|
||||
|
||||
Example:
|
||||
* /tmp/hci.socket
|
||||
* @hci_socket
|
||||
'''
|
||||
|
||||
class UnixPacketSource(StreamPacketSource):
|
||||
def connection_lost(self, exc):
|
||||
logger.debug(f'connection lost: {exc}')
|
||||
self.on_transport_lost()
|
||||
|
||||
# For abstract socket, the first character should be null character.
|
||||
if spec.startswith('@'):
|
||||
spec = '\0' + spec[1:]
|
||||
|
||||
(
|
||||
unix_transport,
|
||||
packet_source,
|
||||
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
|
||||
packet_sink = StreamPacketSink(unix_transport)
|
||||
|
||||
return Transport(packet_source, packet_sink)
|
||||
@@ -15,19 +15,18 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import collections
|
||||
import ctypes
|
||||
import platform
|
||||
|
||||
import usb1
|
||||
|
||||
from bumble.transport.common import Transport, ParserSource
|
||||
from bumble.transport.common import Transport, ParserSource, TransportInitError
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -115,13 +114,17 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
self.device = device
|
||||
self.acl_out = acl_out
|
||||
self.acl_out_transfer = device.getTransfer()
|
||||
self.packets = collections.deque() # Queue of packets waiting to be sent
|
||||
self.acl_out_transfer_ready = asyncio.Semaphore(1)
|
||||
self.packets: asyncio.Queue[bytes] = (
|
||||
asyncio.Queue()
|
||||
) # Queue of packets waiting to be sent
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue_task = None
|
||||
self.cancel_done = self.loop.create_future()
|
||||
self.closed = False
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
self.queue_task = asyncio.create_task(self.process_queue())
|
||||
|
||||
def on_packet(self, packet):
|
||||
# Ignore packets if we're closed
|
||||
@@ -133,62 +136,64 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
return
|
||||
|
||||
# Queue the packet
|
||||
self.packets.append(packet)
|
||||
if len(self.packets) == 1:
|
||||
# The queue was previously empty, re-prime the pump
|
||||
self.process_queue()
|
||||
self.packets.put_nowait(packet)
|
||||
|
||||
def transfer_callback(self, transfer):
|
||||
self.acl_out_transfer_ready.release()
|
||||
status = transfer.getStatus()
|
||||
|
||||
# pylint: disable=no-member
|
||||
if status == usb1.TRANSFER_COMPLETED:
|
||||
self.loop.call_soon_threadsafe(self.on_packet_sent)
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
if status == usb1.TRANSFER_CANCELLED:
|
||||
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
|
||||
else:
|
||||
return
|
||||
|
||||
if status != usb1.TRANSFER_COMPLETED:
|
||||
logger.warning(
|
||||
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
||||
)
|
||||
|
||||
def on_packet_sent(self):
|
||||
if self.packets:
|
||||
self.packets.popleft()
|
||||
self.process_queue()
|
||||
async def process_queue(self):
|
||||
while True:
|
||||
# Wait for a packet to transfer.
|
||||
packet = await self.packets.get()
|
||||
|
||||
def process_queue(self):
|
||||
if len(self.packets) == 0:
|
||||
return # Nothing to do
|
||||
# Wait until we can start a transfer.
|
||||
await self.acl_out_transfer_ready.acquire()
|
||||
|
||||
packet = self.packets[0]
|
||||
packet_type = packet[0]
|
||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.acl_out_transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.transfer_callback
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.acl_out_transfer.setControl(
|
||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
packet[1:],
|
||||
callback=self.transfer_callback,
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
else:
|
||||
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
|
||||
# Transfer the packet.
|
||||
packet_type = packet[0]
|
||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.acl_out_transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.transfer_callback
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.acl_out_transfer.setControl(
|
||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
packet[1:],
|
||||
callback=self.transfer_callback,
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'unsupported packet type {packet_type}', 'red')
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
if self.queue_task:
|
||||
self.queue_task.cancel()
|
||||
|
||||
async def terminate(self):
|
||||
if not self.closed:
|
||||
self.close()
|
||||
|
||||
# Empty the packet queue so that we don't send any more data
|
||||
self.packets.clear()
|
||||
while not self.packets.empty():
|
||||
self.packets.get_nowait()
|
||||
|
||||
# If we have a transfer in flight, cancel it
|
||||
if self.acl_out_transfer.isSubmitted():
|
||||
@@ -442,7 +447,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
|
||||
if found is None:
|
||||
context.close()
|
||||
raise ValueError('device not found')
|
||||
raise TransportInitError('device not found')
|
||||
|
||||
logger.debug(f'USB Device: {found}')
|
||||
|
||||
@@ -507,7 +512,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
|
||||
endpoints = find_endpoints(found)
|
||||
if endpoints is None:
|
||||
raise ValueError('no compatible interface found for device')
|
||||
raise TransportInitError('no compatible interface found for device')
|
||||
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
|
||||
logger.debug(
|
||||
f'selected endpoints: configuration={configuration}, '
|
||||
|
||||
BIN
docs/images/favicon.ico
Normal file
BIN
docs/images/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
7
examples/device_with_rpa.json
Normal file
7
examples/device_with_rpa.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"name": "Bumble",
|
||||
"address": "F0:F1:F2:F3:F4:F5",
|
||||
"keystore": "JsonKeyStore",
|
||||
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
|
||||
"le_privacy_enabled": true
|
||||
}
|
||||
@@ -3,5 +3,6 @@
|
||||
"keystore": "JsonKeyStore",
|
||||
"address": "F0:F1:F2:F3:F4:FA",
|
||||
"class_of_device": 2376708,
|
||||
"cis_enabled": true,
|
||||
"advertising_interval": 100
|
||||
}
|
||||
|
||||
83
examples/mcp_server.html
Normal file
83
examples/mcp_server.html
Normal file
@@ -0,0 +1,83 @@
|
||||
<html data-bs-theme="dark">
|
||||
|
||||
<head>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
|
||||
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<nav class="navbar navbar-dark bg-primary">
|
||||
<div class="container">
|
||||
<span class="navbar-brand mb-0 h1">Bumble LEA Media Control Client</span>
|
||||
</div>
|
||||
</nav>
|
||||
<br>
|
||||
|
||||
<div class="container">
|
||||
|
||||
<label class="form-label">Server Port</label>
|
||||
<div class="input-group mb-3">
|
||||
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
|
||||
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
|
||||
</div>
|
||||
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x01)">Play</button>
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x02)">Pause</button>
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x03)">Fast Rewind</button>
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x04)">Fast Forward</button>
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x05)">Stop</button>
|
||||
|
||||
</br></br>
|
||||
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x30)">Previous Track</button>
|
||||
<button class="btn btn-primary" onclick="send_opcode(0x31)">Next Track</button>
|
||||
|
||||
<hr>
|
||||
|
||||
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
|
||||
<h3>Log</h3>
|
||||
<code id="log" style="white-space: pre-line;"></code>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<script>
|
||||
let portInput = document.getElementById("port")
|
||||
let log = document.getElementById("log")
|
||||
let socket
|
||||
|
||||
function connect() {
|
||||
socket = new WebSocket(`ws://localhost:${portInput.value}`);
|
||||
socket.onopen = _ => {
|
||||
log.textContent += 'OPEN\n'
|
||||
}
|
||||
socket.onclose = _ => {
|
||||
log.textContent += 'CLOSED\n'
|
||||
}
|
||||
socket.onerror = (error) => {
|
||||
log.textContent += 'ERROR\n'
|
||||
console.log(`ERROR: ${error}`)
|
||||
}
|
||||
socket.onmessage = (event) => {
|
||||
log.textContent += `<-- ${event.data}\n`
|
||||
}
|
||||
}
|
||||
|
||||
function send(message) {
|
||||
if (socket && socket.readyState == WebSocket.OPEN) {
|
||||
let jsonMessage = JSON.stringify(message)
|
||||
log.textContent += `--> ${jsonMessage}\n`
|
||||
socket.send(jsonMessage)
|
||||
} else {
|
||||
log.textContent += 'NOT CONNECTED\n'
|
||||
}
|
||||
}
|
||||
|
||||
function send_opcode(opcode) {
|
||||
send({ 'opcode': opcode })
|
||||
}
|
||||
</script>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
196
examples/run_mcp_client.py
Normal file
196
examples/run_mcp_client.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import websockets
|
||||
import json
|
||||
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import (
|
||||
Device,
|
||||
AdvertisingParameters,
|
||||
AdvertisingEventProperties,
|
||||
Connection,
|
||||
Peer,
|
||||
)
|
||||
from bumble.hci import (
|
||||
CodecID,
|
||||
CodingFormat,
|
||||
OwnAddressType,
|
||||
)
|
||||
from bumble.profiles.bap import (
|
||||
CodecSpecificCapabilities,
|
||||
ContextType,
|
||||
AudioLocation,
|
||||
SupportedSamplingFrequency,
|
||||
SupportedFrameDuration,
|
||||
PacRecord,
|
||||
PublishedAudioCapabilitiesService,
|
||||
AudioStreamControlService,
|
||||
UnicastServerAdvertisingData,
|
||||
)
|
||||
from bumble.profiles.mcp import (
|
||||
MediaControlServiceProxy,
|
||||
GenericMediaControlServiceProxy,
|
||||
MediaState,
|
||||
MediaControlPointOpcode,
|
||||
)
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: run_mcp_client.py <config-file>' '<transport-spec-for-device>')
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
|
||||
print('<<< connected')
|
||||
|
||||
device = Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
# Add "placeholder" services to enable Android LEA features.
|
||||
device.add_service(
|
||||
PublishedAudioCapabilitiesService(
|
||||
supported_source_context=ContextType.PROHIBITED,
|
||||
available_source_context=ContextType.PROHIBITED,
|
||||
supported_sink_context=ContextType.MEDIA,
|
||||
available_sink_context=ContextType.MEDIA,
|
||||
sink_audio_locations=(
|
||||
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
sink_pac=[
|
||||
PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
SupportedSamplingFrequency.FREQ_16000
|
||||
| SupportedSamplingFrequency.FREQ_32000
|
||||
| SupportedSamplingFrequency.FREQ_48000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_count=[1, 2],
|
||||
min_octets_per_codec_frame=0,
|
||||
max_octets_per_codec_frame=320,
|
||||
supported_max_codec_frames_per_sdu=2,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
device.add_service(AudioStreamControlService(device, sink_ase_id=[1]))
|
||||
|
||||
ws: Optional[websockets.WebSocketServerProtocol] = None
|
||||
mcp: Optional[MediaControlServiceProxy] = None
|
||||
|
||||
advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes('Bumble LE Audio', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
) + bytes(UnicastServerAdvertisingData())
|
||||
|
||||
await device.create_advertising_set(
|
||||
advertising_parameters=AdvertisingParameters(
|
||||
advertising_event_properties=AdvertisingEventProperties(),
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
primary_advertising_interval_max=100,
|
||||
primary_advertising_interval_min=100,
|
||||
),
|
||||
advertising_data=advertising_data,
|
||||
auto_restart=True,
|
||||
)
|
||||
|
||||
def on_media_state(media_state: MediaState) -> None:
|
||||
if ws:
|
||||
asyncio.create_task(
|
||||
ws.send(json.dumps({'media_state': media_state.name}))
|
||||
)
|
||||
|
||||
def on_track_title(title: str) -> None:
|
||||
if ws:
|
||||
asyncio.create_task(ws.send(json.dumps({'title': title})))
|
||||
|
||||
def on_track_duration(duration: int) -> None:
|
||||
if ws:
|
||||
asyncio.create_task(ws.send(json.dumps({'duration': duration})))
|
||||
|
||||
def on_track_position(position: int) -> None:
|
||||
if ws:
|
||||
asyncio.create_task(ws.send(json.dumps({'position': position})))
|
||||
|
||||
def on_connection(connection: Connection) -> None:
|
||||
async def on_connection_async():
|
||||
async with Peer(connection) as peer:
|
||||
nonlocal mcp
|
||||
mcp = peer.create_service_proxy(MediaControlServiceProxy)
|
||||
if not mcp:
|
||||
mcp = peer.create_service_proxy(GenericMediaControlServiceProxy)
|
||||
mcp.on('media_state', on_media_state)
|
||||
mcp.on('track_title', on_track_title)
|
||||
mcp.on('track_duration', on_track_duration)
|
||||
mcp.on('track_position', on_track_position)
|
||||
await mcp.subscribe_characteristics()
|
||||
|
||||
connection.abort_on('disconnection', on_connection_async())
|
||||
|
||||
device.on('connection', on_connection)
|
||||
|
||||
async def serve(websocket: websockets.WebSocketServerProtocol, _path):
|
||||
nonlocal ws
|
||||
ws = websocket
|
||||
async for message in websocket:
|
||||
request = json.loads(message)
|
||||
if mcp:
|
||||
await mcp.write_control_point(
|
||||
MediaControlPointOpcode(request['opcode'])
|
||||
)
|
||||
ws = None
|
||||
|
||||
await websockets.serve(serve, 'localhost', 8989)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -48,6 +48,7 @@ from bumble.profiles.bap import (
|
||||
PublishedAudioCapabilitiesService,
|
||||
PublishedAudioCapabilitiesServiceProxy,
|
||||
)
|
||||
from bumble.profiles.le_audio import Metadata
|
||||
from tests.test_utils import TwoDevices
|
||||
|
||||
|
||||
@@ -97,7 +98,7 @@ def test_pac_record() -> None:
|
||||
pac_record = PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=cap,
|
||||
metadata=b'',
|
||||
metadata=Metadata([Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')]),
|
||||
)
|
||||
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
|
||||
|
||||
@@ -142,7 +143,7 @@ def test_ASE_Config_QOS() -> None:
|
||||
def test_ASE_Enable() -> None:
|
||||
operation = ASE_Enable(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
metadata=[b'', b''],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
@@ -151,7 +152,7 @@ def test_ASE_Enable() -> None:
|
||||
def test_ASE_Update_Metadata() -> None:
|
||||
operation = ASE_Update_Metadata(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
metadata=[b'', b''],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
@@ -276,34 +276,6 @@ async def test_legacy_advertising():
|
||||
assert not device.is_advertising
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising_connection(own_address_type):
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
|
||||
# Start advertising
|
||||
await device.start_advertising()
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
if own_address_type == OwnAddressType.PUBLIC:
|
||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
||||
else:
|
||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
||||
|
||||
await async_barrier()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
@@ -318,6 +290,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
None,
|
||||
None,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
@@ -367,6 +341,8 @@ async def test_extended_advertising_connection(own_address_type):
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
None,
|
||||
None,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
@@ -407,6 +383,8 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
None,
|
||||
None,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
@@ -879,6 +879,57 @@ async def test_unsubscribe():
|
||||
mock1.assert_called_once_with(ANY, False, False)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_all():
|
||||
[client, server] = LinkedDevices().devices[:2]
|
||||
|
||||
characteristic1 = Characteristic(
|
||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([1, 2, 3]),
|
||||
)
|
||||
|
||||
descriptor1 = Descriptor('2902', 'READABLE,WRITEABLE')
|
||||
descriptor2 = Descriptor('AAAA', 'READABLE,WRITEABLE')
|
||||
characteristic2 = Characteristic(
|
||||
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([1, 2, 3]),
|
||||
descriptors=[descriptor1, descriptor2],
|
||||
)
|
||||
|
||||
service1 = Service(
|
||||
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
|
||||
[characteristic1, characteristic2],
|
||||
)
|
||||
service2 = Service('1111', [])
|
||||
server.add_services([service1, service2])
|
||||
|
||||
await client.power_on()
|
||||
await server.power_on()
|
||||
connection = await client.connect(server.random_address)
|
||||
peer = Peer(connection)
|
||||
|
||||
await peer.discover_all()
|
||||
assert len(peer.gatt_client.services) == 3
|
||||
# service 1800 gets added automatically
|
||||
assert peer.gatt_client.services[0].uuid == UUID('1800')
|
||||
assert peer.gatt_client.services[1].uuid == service1.uuid
|
||||
assert peer.gatt_client.services[2].uuid == service2.uuid
|
||||
s = peer.get_services_by_uuid(service1.uuid)
|
||||
assert len(s) == 1
|
||||
assert len(s[0].characteristics) == 2
|
||||
c = peer.get_characteristics_by_uuid(uuid=characteristic2.uuid, service=s[0])
|
||||
assert len(c) == 1
|
||||
assert len(c[0].descriptors) == 2
|
||||
s = peer.get_services_by_uuid(service2.uuid)
|
||||
assert len(s) == 1
|
||||
assert len(s[0].characteristics) == 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_mtu_exchange():
|
||||
@@ -1146,6 +1197,56 @@ def test_get_attribute_group():
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_characteristics_by_uuid():
|
||||
[client, server] = LinkedDevices().devices[:2]
|
||||
|
||||
characteristic1 = Characteristic(
|
||||
'1234',
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([1, 2, 3]),
|
||||
)
|
||||
characteristic2 = Characteristic(
|
||||
'5678',
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([1, 2, 3]),
|
||||
)
|
||||
service1 = Service(
|
||||
'ABCD',
|
||||
[characteristic1, characteristic2],
|
||||
)
|
||||
service2 = Service(
|
||||
'FFFF',
|
||||
[characteristic1],
|
||||
)
|
||||
|
||||
server.add_services([service1, service2])
|
||||
|
||||
await client.power_on()
|
||||
await server.power_on()
|
||||
connection = await client.connect(server.random_address)
|
||||
peer = Peer(connection)
|
||||
|
||||
await peer.discover_services()
|
||||
await peer.discover_characteristics()
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
|
||||
assert len(c) == 2
|
||||
assert isinstance(c[0], CharacteristicProxy)
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
|
||||
assert len(c) == 1
|
||||
assert isinstance(c[0], CharacteristicProxy)
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
|
||||
assert len(c) == 0
|
||||
|
||||
s = peer.get_services_by_uuid(uuid=UUID('ABCD'))
|
||||
assert len(s) == 1
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=s[0])
|
||||
assert len(s) == 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
|
||||
39
tests/le_audio_test.py
Normal file
39
tests/le_audio_test.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from bumble.profiles import le_audio
|
||||
|
||||
|
||||
def test_parse_metadata():
|
||||
metadata = le_audio.Metadata(
|
||||
entries=[
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.PROGRAM_INFO,
|
||||
data=b'',
|
||||
),
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
|
||||
data=bytes([0, 0]),
|
||||
),
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
|
||||
data=bytes([1, 2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert le_audio.Metadata.from_bytes(bytes(metadata)) == metadata
|
||||
132
tests/mcp_test.py
Normal file
132
tests/mcp_test.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import struct
|
||||
import logging
|
||||
|
||||
from bumble import device
|
||||
from bumble.profiles import mcp
|
||||
from tests.test_utils import TwoDevices
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
TIMEOUT = 0.1
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GmcsContext:
|
||||
devices: TwoDevices
|
||||
client: mcp.GenericMediaControlServiceProxy
|
||||
server: mcp.GenericMediaControlService
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest_asyncio.fixture
|
||||
async def gmcs_context():
|
||||
devices = TwoDevices()
|
||||
server = mcp.GenericMediaControlService()
|
||||
devices[0].add_service(server)
|
||||
|
||||
await devices.setup_connection()
|
||||
devices.connections[0].encryption = 1
|
||||
devices.connections[1].encryption = 1
|
||||
peer = device.Peer(devices.connections[1])
|
||||
client = await peer.discover_service_and_create_proxy(
|
||||
mcp.GenericMediaControlServiceProxy
|
||||
)
|
||||
await client.subscribe_characteristics()
|
||||
|
||||
return GmcsContext(devices=devices, server=server, client=client)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_media_state(gmcs_context):
|
||||
state = asyncio.Queue()
|
||||
gmcs_context.client.on('media_state', state.put_nowait)
|
||||
|
||||
await gmcs_context.devices[0].notify_subscribers(
|
||||
gmcs_context.server.media_state_characteristic,
|
||||
value=bytes([mcp.MediaState.PLAYING]),
|
||||
)
|
||||
|
||||
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == mcp.MediaState.PLAYING
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_track_title(gmcs_context):
|
||||
state = asyncio.Queue()
|
||||
gmcs_context.client.on('track_title', state.put_nowait)
|
||||
|
||||
await gmcs_context.devices[0].notify_subscribers(
|
||||
gmcs_context.server.track_title_characteristic,
|
||||
value="My Song".encode(),
|
||||
)
|
||||
|
||||
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == "My Song"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_track_duration(gmcs_context):
|
||||
state = asyncio.Queue()
|
||||
gmcs_context.client.on('track_duration', state.put_nowait)
|
||||
|
||||
await gmcs_context.devices[0].notify_subscribers(
|
||||
gmcs_context.server.track_duration_characteristic,
|
||||
value=struct.pack("<i", 1000),
|
||||
)
|
||||
|
||||
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_track_position(gmcs_context):
|
||||
state = asyncio.Queue()
|
||||
gmcs_context.client.on('track_position', state.put_nowait)
|
||||
|
||||
await gmcs_context.devices[0].notify_subscribers(
|
||||
gmcs_context.server.track_position_characteristic,
|
||||
value=struct.pack("<i", 1000),
|
||||
)
|
||||
|
||||
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_media_control_point(gmcs_context):
|
||||
assert (
|
||||
await asyncio.wait_for(
|
||||
gmcs_context.client.write_control_point(mcp.MediaControlPointOpcode.PAUSE),
|
||||
TIMEOUT,
|
||||
)
|
||||
) == mcp.MediaControlPointResultCode.SUCCESS
|
||||
@@ -50,4 +50,6 @@ Example:
|
||||
|
||||
|
||||
NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory.
|
||||
Make a copy of the built `.whl` file in the `web` directory.
|
||||
Make a copy of the built `.whl` file in the `web` directory.
|
||||
|
||||
Tip: During web developement, disable caching. [Chrome](https://stackoverflow.com/a/7000899]) / [Firefiox](https://stackoverflow.com/a/289771)
|
||||
1
web/favicon.ico
Symbolic link
1
web/favicon.ico
Symbolic link
@@ -0,0 +1 @@
|
||||
../docs/images/favicon.ico
|
||||
@@ -15,12 +15,21 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import pyee
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.hci import HCI_Reset_Command
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Scanner:
|
||||
class Scanner(pyee.EventEmitter):
|
||||
"""
|
||||
Scanner web app
|
||||
|
||||
Emitted events:
|
||||
update: Emit when new `ScanEntry` are available.
|
||||
"""
|
||||
|
||||
class ScanEntry:
|
||||
def __init__(self, advertisement):
|
||||
self.address = advertisement.address.to_string(False)
|
||||
@@ -39,13 +48,12 @@ class Scanner:
|
||||
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||
)
|
||||
self.scan_entries = {}
|
||||
self.listeners = {}
|
||||
self.device.on('advertisement', self.on_advertisement)
|
||||
|
||||
async def start(self):
|
||||
print('### Starting Scanner')
|
||||
self.scan_entries = {}
|
||||
self.emit_update()
|
||||
self.emit('update', self.scan_entries)
|
||||
await self.device.power_on()
|
||||
await self.device.start_scanning()
|
||||
print('### Scanner started')
|
||||
@@ -56,16 +64,9 @@ class Scanner:
|
||||
await self.device.power_off()
|
||||
print('### Scanner stopped')
|
||||
|
||||
def emit_update(self):
|
||||
if listener := self.listeners.get('update'):
|
||||
listener(list(self.scan_entries.values()))
|
||||
|
||||
def on(self, event_name, listener):
|
||||
self.listeners[event_name] = listener
|
||||
|
||||
def on_advertisement(self, advertisement):
|
||||
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
|
||||
self.emit_update()
|
||||
self.emit('update', self.scan_entries)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user