Compare commits

...

32 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
5e4055bb6b fix attribute discovery and display 2024-08-11 09:33:53 -07:00
Gilles Boccon-Gibod
4433184048 Merge pull request #522 from google/gbg/rpa2
add basic RPA support
2024-08-06 10:35:39 -07:00
Gilles Boccon-Gibod
312fc8db36 support controller-generated rpa 2024-08-05 08:59:05 -07:00
Gilles Boccon-Gibod
615691ec81 add basic RPA support 2024-08-01 15:37:11 -07:00
zxzxwu
ae8b83f294 Merge pull request #521 from zxzxwu/bap
Add Metadata LTV serializer and adapt Unicast
2024-07-31 11:36:46 +08:00
Josh Wu
4a8e21f4db Add Metadata LTV serializer and adapt Unicast 2024-07-31 01:20:28 +08:00
zxzxwu
3462e7c437 Merge pull request #439 from zxzxwu/mcp
Media Control Service Client implementation
2024-07-24 23:45:00 +08:00
Josh Wu
0f2e5239ad MCP constants and Client implementation 2024-07-24 22:57:26 +08:00
Gilles Boccon-Gibod
ee48cdc63f Merge pull request #517 from AlanRosenthal/scanner_pyee
Update scanner.py to use pyee.EventEmitter
2024-07-18 12:53:00 -07:00
Gilles Boccon-Gibod
1c278bec93 Merge pull request #518 from google/gbg/usb-queue
USB: better packet queue logic
2024-07-18 12:51:00 -07:00
Gilles Boccon-Gibod
6a51166af7 better packet queue logic 2024-07-17 17:48:26 -07:00
Alan Rosenthal
85d79fa914 Update scanner.py to use pyee.EventEmitter 2024-07-17 16:53:50 -04:00
zxzxwu
142bdce94a Merge pull request #515 from zxzxwu/unix
Add UNIX socket transport
2024-07-17 16:04:38 +08:00
Josh Wu
881a5a64b5 Add UNIX socket transport 2024-07-17 00:41:04 +08:00
zxzxwu
5aae44b610 Merge pull request #501 from zxzxwu/exception
Reorganize exceptions
2024-07-12 15:44:58 +08:00
Gilles Boccon-Gibod
e3ea167827 Merge pull request #506 from google/gbg/a2dp-fixes
a2dp: emit delay_report
2024-07-11 18:46:06 -07:00
Gilles Boccon-Gibod
eec145e095 add type hint 2024-07-11 18:39:02 -07:00
Gilles Boccon-Gibod
87fa02d6e5 Merge pull request #507 from google/packageFile
Create `inv web.build`
2024-07-11 18:35:29 -07:00
Gilles Boccon-Gibod
ad94c1e1f3 Merge pull request #509 from AlanRosenthal/discover
device.py: Add discover_all() api
2024-07-11 18:34:29 -07:00
Gilles Boccon-Gibod
546a0bce8d Merge pull request #510 from AlanRosenthal/get_characteristics_by_uuid
device.py: Update get_characteristics_by_uuid()
2024-07-11 18:33:45 -07:00
Gilles Boccon-Gibod
cb7ca44a1c Merge pull request #512 from AlanRosenthal/favicon
Add favicon.ico to docs folder
2024-07-11 18:27:19 -07:00
Gilles Boccon-Gibod
4081b93407 Merge pull request #513 from AlanRosenthal/devcontainer
Add devcontainer.json
2024-07-11 18:24:09 -07:00
Alan Rosenthal
26203ebaad Add devcontainer.json
devcontainer.json allows github's codespaces to be created with bumble's dependencies already installed
2024-07-11 18:47:32 +00:00
Alan Rosenthal
3389e3e1ed device.py: Update get_characteristics_by_uuid()
`get_characteristics_by_uuid()` now allows a UUID to be passed to the
service param. This allows for users to easily query for a service uuid
and characteristic uuid with one API.
2024-07-11 18:05:41 +00:00
Alan Rosenthal
7e1f01c01e Add favicon.ico to docs folder
Generated via: realfavicongenerator.net

validated via:
```
$ icotool -l favicon.ico
--icon --index=1 --width=48 --height=48 --bit-depth=32 --palette-size=0
--icon --index=2 --width=32 --height=32 --bit-depth=32 --palette-size=0
--icon --index=3 --width=16 --height=16 --bit-depth=32 --palette-size=0
```
2024-07-11 09:47:19 -04:00
Gilles Boccon-Gibod
613e15548a Merge pull request #511 from AlanRosenthal/random
console.py: Use Address.generate_static_address
2024-07-10 13:45:52 -07:00
Alan Rosenthal
e09c91df8e console.py: Use Address.generate_static_address 2024-07-10 18:51:46 +00:00
Alan Rosenthal
df206667b6 device.py: Add discover_all() api 2024-07-10 13:24:08 -04:00
Gilles Boccon-Gibod
0f19dd5263 Merge pull request #508 from google/web-readme
Add tip about disabling caching to web's readme
2024-07-09 09:17:25 -07:00
Alan Rosenthal
b98e4937f3 Add tip about disabling caching to web's readme 2024-07-09 13:48:53 +00:00
Gilles Boccon-Gibod
27791cf218 emit delay_report 2024-07-03 13:51:15 -07:00
Josh Wu
f8a2d4f0e0 Reorganize exceptions
* Add BaseBumbleException as a "real" root error
* Add several core error classes and properly replace builtin errors
  with them
* Add several error classes for specific modules (transport, device)
2024-06-11 16:13:08 +08:00
46 changed files with 1667 additions and 315 deletions

View File

@@ -0,0 +1,30 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/universal:2",
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand":
"python -m pip install '.[build,test,development,documentation]'",
// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python"
]
}
}
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.gatt_client import CharacteristicProxy
from bumble.hci import (
Address,
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
@@ -167,6 +168,7 @@ class ConsoleApp:
'remote-services': None,
'local-values': None,
'remote-values': None,
'remote-attributes': None,
},
'filter': {
'address': None,
@@ -215,6 +217,7 @@ class ConsoleApp:
)
self.local_values_text = FormattedTextControl()
self.remote_values_text = FormattedTextControl()
self.remote_attributes_text = FormattedTextControl()
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = []
@@ -241,6 +244,12 @@ class ConsoleApp:
Frame(Window(self.remote_values_text), title='Remote Values'),
filter=Condition(lambda: self.top_tab == 'remote-values'),
),
ConditionalContainer(
Frame(
Window(self.remote_attributes_text), title='Remote Attributes'
),
filter=Condition(lambda: self.top_tab == 'remote-attributes'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log'),
@@ -289,11 +298,7 @@ class ConsoleApp:
device_config, hci_source, hci_sink
)
else:
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
random_address = Address.generate_static_address()
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
@@ -503,19 +508,9 @@ class ConsoleApp:
self.show_error('not connected')
return
# Discover all services, characteristics and descriptors
self.append_to_output('discovering services...')
await self.connected_peer.discover_services()
self.append_to_output(
f'found {len(self.connected_peer.services)} services,'
' discovering characteristics...'
)
await self.connected_peer.discover_characteristics()
self.append_to_output('found characteristics, discovering descriptors...')
for service in self.connected_peer.services:
for characteristic in service.characteristics:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.append_to_output('Service Discovery starting...')
await self.connected_peer.discover_all()
self.append_to_output('Service Discovery done!')
self.show_remote_services(self.connected_peer.services)
@@ -529,7 +524,7 @@ class ConsoleApp:
attributes = await self.connected_peer.discover_attributes()
self.append_to_output(f'discovered {len(attributes)} attributes...')
self.show_attributes(attributes)
await self.show_remote_attributes(attributes)
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
if not self.connected_peer:
@@ -674,7 +669,6 @@ class ConsoleApp:
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except bumble.core.TimeoutError:
self.show_error('connection timed out')
@@ -745,19 +739,20 @@ class ConsoleApp:
'remote-services',
'local-values',
'remote-values',
'remote-attributes',
}:
self.top_tab = params[0]
self.ui.invalidate()
while self.top_tab == 'local-values':
await self.do_show_local_values()
await self.show_local_values()
await asyncio.sleep(1)
while self.top_tab == 'remote-values':
await self.do_show_remote_values()
await self.show_remote_values()
await asyncio.sleep(1)
async def do_show_local_values(self):
async def show_local_values(self):
prettytable = PrettyTable()
field_names = ["Service", "Characteristic", "Descriptor"]
@@ -812,7 +807,7 @@ class ConsoleApp:
self.local_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def do_show_remote_values(self):
async def show_remote_values(self):
prettytable = PrettyTable(
field_names=[
"Connection",
@@ -846,6 +841,23 @@ class ConsoleApp:
self.remote_values_text.text = prettytable.get_string()
self.ui.invalidate()
async def show_remote_attributes(self, attributes):
lines = []
for attribute in attributes:
lines.append(('ansimagenta', str(attribute) + "\n"))
try:
value = await attribute.read_value()
lines.append(('ansicyan', value.hex() + "\n"))
except bumble.core.ProtocolError as error:
lines.append(("ansired", f"!!! Protocol Error ({error})\n"))
except bumble.core.TimeoutError:
lines.append(("ansired", "!!! Timeout\n"))
except Exception as error:
lines.append(("ansired", f"!!! Error ({error})\n"))
self.remote_attributes_text.text = lines
self.ui.invalidate()
async def do_get_phy(self, _):
if not self.connected_peer:
self.show_error('not connected')

View File

@@ -14,13 +14,19 @@
from typing import List, Union
from bumble import core
class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens.
Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string."""
Raises AtParsingError in case of invalid input string."""
tokens = []
in_quotes = False
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise ValueError("open_paren following regular character")
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise ValueError("quote following regular character")
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string."""
Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]]
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren")
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator[-1].append(current)
if len(accumulator) > 1:
raise ValueError("missing close_paren")
raise AtParsingError("missing close_paren")
return accumulator[0]

View File

@@ -20,6 +20,7 @@ import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble import core
from bumble.utils import OpenIntEnum
@@ -88,7 +89,9 @@ class Frame:
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise ValueError(f"invalid subclass name {subclass.__name__}")
raise core.InvalidArgumentError(
f"invalid subclass name {subclass.__name__}"
)
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
@@ -106,7 +109,7 @@ class Frame:
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s")
raise core.InvalidPacketError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
@@ -122,7 +125,7 @@ class Frame:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise ValueError("extended subunit ID value reserved")
raise core.InvalidPacketError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
@@ -131,7 +134,7 @@ class Frame:
opcode_offset = 3
elif subunit_id == 6:
raise ValueError("reserved subunit ID")
raise core.InvalidPacketError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
@@ -448,7 +451,7 @@ class PassThroughFrame:
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes")
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data

View File

@@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import core
from bumble import l2cap
# -----------------------------------------------------------------------------
@@ -275,7 +276,7 @@ class Protocol:
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered")
raise core.InvalidArgumentError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
@@ -287,5 +288,5 @@ class Protocol:
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered")
raise core.InvalidArgumentError("response handler not registered")
del self.response_handlers[pid]

View File

@@ -43,6 +43,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
ProtocolError,
InvalidArgumentError,
name_or_number,
)
from .a2dp import (
@@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
signal_identifier_str = name[:-7]
message_type = Message.MessageType.RESPONSE_REJECT
else:
raise ValueError('invalid class name')
raise InvalidArgumentError('invalid class name')
subclass.message_type = message_type
@@ -2162,6 +2163,9 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
def on_abort_command(self):
self.emit('abort')
def on_delayreport_command(self, delay: int):
self.emit('delay_report', delay)
def on_rtp_channel_open(self):
self.emit('rtp_channel_open')

View File

@@ -55,6 +55,7 @@ from bumble.sdp import (
)
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.core import (
InvalidArgumentError,
ProtocolError,
BT_L2CAP_PROTOCOL_ID,
BT_AVCTP_PROTOCOL_ID,
@@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter):
def notify_track_changed(self, identifier: bytes) -> None:
"""Notify the connected peer of a Track change."""
if len(identifier) != 8:
raise ValueError("identifier must be 8 bytes")
raise InvalidArgumentError("identifier must be 8 bytes")
self.notify_event(TrackChangedEvent(identifier))
def notify_playback_position_changed(self, position: int) -> None:

View File

@@ -18,6 +18,8 @@
from __future__ import annotations
from dataclasses import dataclass
from bumble import core
# -----------------------------------------------------------------------------
class BitReader:
@@ -40,7 +42,7 @@ class BitReader:
""" "Read up to 32 bits."""
if bits > 32:
raise ValueError('maximum read size is 32')
raise core.InvalidArgumentError('maximum read size is 32')
if self.bits_cached >= bits:
# We have enough bits.
@@ -53,7 +55,7 @@ class BitReader:
feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data')
raise core.InvalidArgumentError('trying to read past the data')
self.byte_position += feed_size
# Combine the new cache and the old cache
@@ -68,7 +70,7 @@ class BitReader:
def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data')
raise core.InvalidArgumentError('not enough data')
if self.bit_position % 8:
# Not byte aligned
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
@staticmethod
def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported')
raise core.InvalidPacketError('program_config_element not supported')
@dataclass
class GASpecificConfig:
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported')
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod
def audio_object_type(reader: BitReader):
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
reader, self.channel_configuration, self.audio_object_type
)
else:
raise ValueError(
raise core.InvalidPacketError(
f'audioObjectType {self.audio_object_type} not supported'
)
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
else:
audio_mux_version_a = 0
if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported')
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
num_sub_frames = reader.read(6)
num_program = reader.read(4)
if num_program != 0:
raise ValueError('num_program != 0 not supported')
raise core.InvalidPacketError('num_program != 0 not supported')
num_layer = reader.read(3)
if num_layer != 0:
raise ValueError('num_layer != 0 not supported')
raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
)
audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len')
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len
reader.skip(asc_len)
frame_length_type = reader.read(3)
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
elif frame_length_type == 1:
frame_length = reader.read(9)
else:
raise ValueError(f'frame_length_type {frame_length_type} not supported')
raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1)
if self.other_data_present:
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported')
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported')
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support:

View File

@@ -16,6 +16,10 @@ from functools import partial
from typing import List, Optional, Union
class ColorError(ValueError):
"""Error raised when a color spec is invalid."""
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
raise ColorError('Invalid color spec "%s"' % spec)
def color(
@@ -72,7 +76,7 @@ def color(
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
raise ColorError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)

View File

@@ -79,7 +79,13 @@ def get_dict_key_by_value(dictionary, value):
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
class BaseError(Exception):
class BaseBumbleError(Exception):
"""Base Error raised by Bumble."""
class BaseError(BaseBumbleError):
"""Base class for errors with an error code, error name and namespace"""
def __init__(
@@ -118,18 +124,38 @@ class ProtocolError(BaseError):
"""Protocol Error"""
class TimeoutError(Exception): # pylint: disable=redefined-builtin
class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin
"""Timeout Error"""
class CommandTimeoutError(Exception):
class CommandTimeoutError(BaseBumbleError):
"""Command Timeout Error"""
class InvalidStateError(Exception):
class InvalidStateError(BaseBumbleError):
"""Invalid State Error"""
class InvalidArgumentError(BaseBumbleError, ValueError):
"""Invalid Argument Error"""
class InvalidPacketError(BaseBumbleError, ValueError):
"""Invalid Packet Error"""
class InvalidOperationError(BaseBumbleError, RuntimeError):
"""Invalid Operation Error"""
class OutOfResourcesError(BaseBumbleError, RuntimeError):
"""Out of Resources Error"""
class UnreachableError(BaseBumbleError):
"""The code path raising this error should be unreachable."""
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""
@@ -188,12 +214,12 @@ class UUID:
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
raise ValueError('invalid UUID format')
raise InvalidArgumentError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '')
else:
uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError(f"invalid UUID format: {uuid_str}")
raise InvalidArgumentError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name
@@ -218,7 +244,7 @@ class UUID:
return self.register()
raise ValueError('only 2, 4 and 16 bytes are allowed')
raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:

View File

@@ -27,6 +27,7 @@ import copy
from dataclasses import dataclass, field
from enum import Enum, IntEnum
import functools
import itertools
import json
import logging
import secrets
@@ -178,10 +179,16 @@ from .core import (
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
AdvertisingData,
BaseBumbleError,
ConnectionParameterUpdateError,
CommandTimeoutError,
ConnectionParameters,
ConnectionPHY,
InvalidArgumentError,
InvalidOperationError,
InvalidStateError,
OutOfResourcesError,
UnreachableError,
)
from .utils import (
AsyncRunner,
@@ -253,8 +260,9 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
)
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
# fmt: on
# pylint: enable=line-too-long
@@ -266,6 +274,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectLookupError(BaseBumbleError):
"""Error raised when failed to lookup an object."""
# -----------------------------------------------------------------------------
@@ -1133,6 +1143,15 @@ class Peer:
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes()
async def discover_all(self):
await self.discover_services()
for service in self.services:
await self.discover_characteristics(service=service)
for service in self.services:
for characteristic in service.characteristics:
await self.discover_descriptors(characteristic=characteristic)
async def subscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
@@ -1172,8 +1191,20 @@ class Peer:
return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
self,
uuid: core.UUID,
service: Optional[Union[gatt_client.ServiceProxy, core.UUID]] = None,
) -> List[gatt_client.CharacteristicProxy]:
if isinstance(service, core.UUID):
return list(
itertools.chain(
*[
self.get_characteristics_by_uuid(uuid, s)
for s in self.get_services_by_uuid(service)
]
)
)
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
@@ -1274,6 +1305,7 @@ class Connection(CompositeEventEmitter):
handle: int
transport: int
self_address: Address
self_resolvable_address: Optional[Address]
peer_address: Address
peer_resolvable_address: Optional[Address]
peer_le_features: Optional[LeFeatureMask]
@@ -1321,6 +1353,7 @@ class Connection(CompositeEventEmitter):
handle,
transport,
self_address,
self_resolvable_address,
peer_address,
peer_resolvable_address,
role,
@@ -1332,6 +1365,7 @@ class Connection(CompositeEventEmitter):
self.handle = handle
self.transport = transport
self.self_address = self_address
self.self_resolvable_address = self_resolvable_address
self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only
@@ -1365,6 +1399,7 @@ class Connection(CompositeEventEmitter):
None,
BT_BR_EDR_TRANSPORT,
device.public_address,
None,
peer_address,
None,
role,
@@ -1523,7 +1558,9 @@ class Connection(CompositeEventEmitter):
f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'peer_address={self.peer_address})'
f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
)
@@ -1538,8 +1575,9 @@ class DeviceConfiguration:
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
le_enabled: bool = True
# LE host enable 2nd parameter
le_simultaneous_enabled: bool = False
le_privacy_enabled: bool = False
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
classic_enabled: bool = False
classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True
@@ -1555,6 +1593,7 @@ class DeviceConfiguration:
irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None
address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False
def __post_init__(self) -> None:
@@ -1640,7 +1679,9 @@ def with_connection_from_handle(function):
@functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None:
raise ValueError(f'no connection for handle: 0x{connection_handle:04x}')
raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
)
return function(self, connection, *args, **kwargs)
return wrapper
@@ -1655,7 +1696,7 @@ def with_connection_from_address(function):
for connection in self.connections.values():
if connection.peer_address == address:
return function(self, connection, *args, **kwargs)
raise ValueError('no connection for address')
raise ObjectLookupError('no connection for address')
return wrapper
@@ -1705,8 +1746,9 @@ device_host_event_handlers: List[str] = []
# -----------------------------------------------------------------------------
class Device(CompositeEventEmitter):
# Incomplete list of fields.
random_address: Address
public_address: Address
random_address: Address # Random address that may change with RPA
public_address: Address # Public address (obtained from the controller)
static_address: Address # Random address that can be set but does not change
classic_enabled: bool
name: str
class_of_device: int
@@ -1836,15 +1878,19 @@ class Device(CompositeEventEmitter):
config = config or DeviceConfiguration()
self.config = config
self.public_address = Address('00:00:00:00:00:00')
self.name = config.name
self.public_address = Address.ANY
self.random_address = config.address
self.static_address = config.address
self.class_of_device = config.class_of_device
self.keystore = None
self.irk = config.irk
self.le_enabled = config.le_enabled
self.classic_enabled = config.classic_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
self.classic_enabled = config.classic_enabled
self.cis_enabled = config.cis_enabled
self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
@@ -1853,6 +1899,7 @@ class Device(CompositeEventEmitter):
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload
self.address_generation_offload = config.address_generation_offload
# Extended advertising.
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
@@ -1908,6 +1955,7 @@ class Device(CompositeEventEmitter):
if isinstance(address, str):
address = Address(address)
self.random_address = address
self.static_address = address
# Setup SMP
self.smp_manager = smp.Manager(
@@ -2093,7 +2141,7 @@ class Device(CompositeEventEmitter):
spec=spec,
)
else:
raise ValueError(f'Unexpected mode {spec}')
raise InvalidArgumentError(f'Unexpected mode {spec}')
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -2139,6 +2187,16 @@ class Device(CompositeEventEmitter):
)
if self.le_enabled:
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
logger.info(f'Initial RPA: {self.random_address}')
if self.le_rpa_timeout > 0:
# Start a task to periodically generate a new RPA
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
# Set the controller address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
@@ -2218,9 +2276,45 @@ class Device(CompositeEventEmitter):
async def power_off(self) -> None:
if self.powered_on:
if self.le_rpa_periodic_update_task:
self.le_rpa_periodic_update_task.cancel()
await self.host.flush()
self.powered_on = False
async def update_rpa(self) -> bool:
"""
Try to update the RPA.
Returns:
True if the RPA was updated, False if it could not be updated.
"""
# Check if this is a good time to rotate the address
if self.is_advertising or self.is_scanning or self.is_le_connecting:
logger.debug('skipping RPA update')
return False
random_address = Address.generate_private_address(self.irk)
response = await self.send_command(
HCI_LE_Set_Random_Address_Command(random_address=self.random_address)
)
if response.return_parameters == HCI_SUCCESS:
logger.info(f'new RPA: {random_address}')
self.random_address = random_address
return True
else:
logger.warning(f'failed to set RPA: {response.return_parameters}')
return False
async def _run_rpa_periodic_update(self) -> None:
"""Update the RPA periodically"""
while self.le_rpa_timeout != 0:
await asyncio.sleep(self.le_rpa_timeout)
if not self.update_rpa():
logger.debug("periodic RPA update failed")
async def refresh_resolving_list(self) -> None:
assert self.keystore is not None
@@ -2228,7 +2322,7 @@ class Device(CompositeEventEmitter):
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload:
if self.address_resolution_offload or self.address_generation_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
# Add an empty entry for non-directed address generation.
@@ -2254,7 +2348,7 @@ class Device(CompositeEventEmitter):
def supports_le_features(self, feature: LeFeatureMask) -> bool:
return self.host.supports_le_features(feature)
def supports_le_phy(self, phy):
def supports_le_phy(self, phy: int) -> bool:
if phy == HCI_LE_1M_PHY:
return True
@@ -2263,7 +2357,7 @@ class Device(CompositeEventEmitter):
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
}
if phy not in feature_map:
raise ValueError('invalid PHY')
raise InvalidArgumentError('invalid PHY')
return self.supports_le_features(feature_map[phy])
@@ -2323,7 +2417,7 @@ class Device(CompositeEventEmitter):
# Decide what peer address to use
if advertising_type.is_directed:
if target is None:
raise ValueError('directed advertising requires a target')
raise InvalidArgumentError('directed advertising requires a target')
peer_address = target
else:
peer_address = Address.ANY
@@ -2430,7 +2524,7 @@ class Device(CompositeEventEmitter):
and advertising_data
and scan_response_data
):
raise ValueError(
raise InvalidArgumentError(
"Extended advertisements can't have both data and scan \
response data"
)
@@ -2446,7 +2540,9 @@ class Device(CompositeEventEmitter):
if handle not in self.extended_advertising_sets
)
except StopIteration as exc:
raise RuntimeError("all valid advertising handles already in use") from exc
raise OutOfResourcesError(
"all valid advertising handles already in use"
) from exc
# Use the device's random address if a random address is needed but none was
# provided.
@@ -2545,14 +2641,14 @@ class Device(CompositeEventEmitter):
) -> None:
# Check that the arguments are legal
if scan_interval < scan_window:
raise ValueError('scan_interval must be >= scan_window')
raise InvalidArgumentError('scan_interval must be >= scan_window')
if (
scan_interval < DEVICE_MIN_SCAN_INTERVAL
or scan_interval > DEVICE_MAX_SCAN_INTERVAL
):
raise ValueError('scan_interval out of range')
raise InvalidArgumentError('scan_interval out of range')
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
raise ValueError('scan_interval out of range')
raise InvalidArgumentError('scan_interval out of range')
# Reset the accumulators
self.advertisement_accumulators = {}
@@ -2580,7 +2676,7 @@ class Device(CompositeEventEmitter):
scanning_phy_count += 1
if scanning_phy_count == 0:
raise ValueError('at least one scanning PHY must be enabled')
raise InvalidArgumentError('at least one scanning PHY must be enabled')
await self.send_command(
HCI_LE_Set_Extended_Scan_Parameters_Command(
@@ -2884,7 +2980,7 @@ class Device(CompositeEventEmitter):
# Check parameters
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
raise ValueError('invalid transport')
raise InvalidArgumentError('invalid transport')
# Adjust the transport automatically if we need to
if transport == BT_LE_TRANSPORT and not self.le_enabled:
@@ -2901,7 +2997,7 @@ class Device(CompositeEventEmitter):
peer_address = Address.from_string_for_transport(
peer_address, transport
)
except ValueError:
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(
@@ -2913,7 +3009,7 @@ class Device(CompositeEventEmitter):
transport == BT_BR_EDR_TRANSPORT
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
):
raise ValueError('BR/EDR addresses must be PUBLIC')
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, Address)
@@ -2964,7 +3060,7 @@ class Device(CompositeEventEmitter):
)
)
if not phys:
raise ValueError('at least one supported PHY needed')
raise InvalidArgumentError('at least one supported PHY needed')
phy_count = len(phys)
initiating_phys = phy_list_to_bits(phys)
@@ -3036,7 +3132,7 @@ class Device(CompositeEventEmitter):
)
else:
if HCI_LE_1M_PHY not in connection_parameters_preferences:
raise ValueError('1M PHY preferences required')
raise InvalidArgumentError('1M PHY preferences required')
prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
result = await self.send_command(
@@ -3136,7 +3232,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str):
try:
peer_address = Address(peer_address)
except ValueError:
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(
@@ -3146,7 +3242,7 @@ class Device(CompositeEventEmitter):
assert isinstance(peer_address, Address)
if peer_address == Address.NIL:
raise ValueError('accept on nil address')
raise InvalidArgumentError('accept on nil address')
# Create a future so that we can wait for the request
pending_request_fut = asyncio.get_running_loop().create_future()
@@ -3259,7 +3355,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str):
try:
peer_address = Address(peer_address)
except ValueError:
except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(
@@ -3302,10 +3398,10 @@ class Device(CompositeEventEmitter):
async def set_data_length(self, connection, tx_octets, tx_time) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB:
raise ValueError('tx_octets must be between 0x001B and 0x00FB')
raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
if tx_time < 0x0148 or tx_time > 0x4290:
raise ValueError('tx_time must be between 0x0148 and 0x4290')
raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')
return await self.send_command(
HCI_LE_Set_Data_Length_Command(
@@ -3580,7 +3676,7 @@ class Device(CompositeEventEmitter):
async def encrypt(self, connection, enable=True):
if not enable and connection.transport == BT_LE_TRANSPORT:
raise ValueError('`enable` parameter is classic only.')
raise InvalidArgumentError('`enable` parameter is classic only.')
# Set up event handlers
pending_encryption = asyncio.get_running_loop().create_future()
@@ -3599,11 +3695,11 @@ class Device(CompositeEventEmitter):
if connection.transport == BT_LE_TRANSPORT:
# Look for a key in the key store
if self.keystore is None:
raise RuntimeError('no key store')
raise InvalidOperationError('no key store')
keys = await self.keystore.get(str(connection.peer_address))
if keys is None:
raise RuntimeError('keys not found in key store')
raise InvalidOperationError('keys not found in key store')
if keys.ltk is not None:
ltk = keys.ltk.value
@@ -3614,7 +3710,7 @@ class Device(CompositeEventEmitter):
rand = keys.ltk_central.rand
ediv = keys.ltk_central.ediv
else:
raise RuntimeError('no LTK found for peer')
raise InvalidOperationError('no LTK found for peer')
if connection.role != HCI_CENTRAL_ROLE:
raise InvalidStateError('only centrals can start encryption')
@@ -3889,7 +3985,7 @@ class Device(CompositeEventEmitter):
return cis_link
# Mypy believes this is reachable when context is an ExitStack.
raise InvalidStateError('Unreachable')
raise UnreachableError()
# [LE only]
@experimental('Only for testing.')
@@ -4071,12 +4167,14 @@ class Device(CompositeEventEmitter):
@host_event_handler
def on_connection(
self,
connection_handle,
transport,
peer_address,
role,
connection_parameters,
):
connection_handle: int,
transport: int,
peer_address: Address,
self_resolvable_address: Optional[Address],
peer_resolvable_address: Optional[Address],
role: int,
connection_parameters: ConnectionParameters,
) -> None:
logger.debug(
f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
@@ -4097,15 +4195,15 @@ class Device(CompositeEventEmitter):
return
# Resolve the peer address if we can
peer_resolvable_address = None
if self.address_resolver:
if peer_address.is_resolvable:
resolved_address = self.address_resolver.resolve(peer_address)
if resolved_address is not None:
logger.debug(f'*** Address resolved as {resolved_address}')
peer_resolvable_address = peer_address
peer_address = resolved_address
if peer_resolvable_address is None:
# Resolve the peer address if we can
if self.address_resolver:
if peer_address.is_resolvable:
resolved_address = self.address_resolver.resolve(peer_address)
if resolved_address is not None:
logger.debug(f'*** Address resolved as {resolved_address}')
peer_resolvable_address = peer_address
peer_address = resolved_address
self_address = None
if role == HCI_CENTRAL_ROLE:
@@ -4136,12 +4234,19 @@ class Device(CompositeEventEmitter):
else self.random_address
)
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if peer_resolvable_address == Address.ANY_RANDOM:
peer_resolvable_address = None
# Create a connection.
connection = Connection(
self,
connection_handle,
transport,
self_address,
self_resolvable_address,
peer_address,
peer_resolvable_address,
role,
@@ -4152,9 +4257,10 @@ class Device(CompositeEventEmitter):
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
if self.legacy_advertiser.auto_restart:
advertiser = self.legacy_advertiser
connection.once(
'disconnection',
lambda _: self.abort_on('flush', self.legacy_advertiser.start()),
lambda _: self.abort_on('flush', advertiser.start()),
)
else:
self.legacy_advertiser = None
@@ -4377,7 +4483,7 @@ class Device(CompositeEventEmitter):
return await pairing_config.delegate.confirm(auto=True)
async def na() -> bool:
assert False, "N/A: unreachable"
raise UnreachableError()
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
methods = {
@@ -4838,5 +4944,6 @@ class Device(CompositeEventEmitter):
return (
f'Device(name="{self.name}", '
f'random_address="{self.random_address}", '
f'public_address="{self.public_address}")'
f'public_address="{self.public_address}", '
f'static_address="{self.static_address}")'
)

View File

@@ -33,6 +33,7 @@ from typing import Tuple
import weakref
from bumble import core
from bumble.hci import (
hci_vendor_command_op_code,
STATUS_SPEC,
@@ -49,6 +50,10 @@ from bumble.drivers import common
logger = logging.getLogger(__name__)
class RtkFirmwareError(core.BaseBumbleError):
"""Error raised when RTK firmware initialization fails."""
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -208,15 +213,15 @@ class Firmware:
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise ValueError("Firmware does not start with epatch signature")
raise RtkFirmwareError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig):
raise ValueError("Firmware does not end with extension sig")
raise RtkFirmwareError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header.
epatch_header_size = 14
if len(firmware) < epatch_header_size:
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
# Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig)
@@ -230,7 +235,7 @@ class Firmware:
break
if length == 0:
raise ValueError("Invalid 0-length instruction")
raise RtkFirmwareError("Invalid 0-length instruction")
if opcode == 0 and length == 1:
project_id = firmware[offset - 1]
@@ -239,7 +244,7 @@ class Firmware:
offset -= length
if project_id < 0:
raise ValueError("Project ID not found")
raise RtkFirmwareError("Project ID not found")
self.project_id = project_id
@@ -252,7 +257,7 @@ class Firmware:
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware):
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
@@ -266,7 +271,7 @@ class Firmware:
"<I", firmware, patch_offset_table_offset + 4 * patch_index
)
if patch_offset + patch_length > len(firmware):
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
# Get the SVN version for the patch
(svn_version,) = struct.unpack_from(
@@ -645,7 +650,7 @@ class Driver(common.Driver):
):
return await self.download_for_rtl8723b()
raise ValueError("ROM not supported")
raise RtkFirmwareError("ROM not supported")
async def init_controller(self):
await self.download_firmware()

View File

@@ -331,9 +331,9 @@ class Client:
async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF:
raise ValueError('MTU must be <= 0xFFFF')
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection
if self.mtu_exchange_done:
@@ -977,6 +977,7 @@ class Client:
offset += len(part)
self.cache_value(attribute_handle, attribute_value)
# Return the value as bytes
return attribute_value

View File

@@ -31,6 +31,8 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT,
AdvertisingData,
DeviceClass,
InvalidArgumentError,
InvalidPacketError,
ProtocolError,
bit_flags_to_strings,
name_or_number,
@@ -92,14 +94,14 @@ def map_class_of_device(class_of_device):
)
def phy_list_to_bits(phys):
def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int:
if phys is None:
return 0
phy_bits = 0
for phy in phys:
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
raise ValueError('invalid PHY')
raise InvalidArgumentError('invalid PHY')
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
return phy_bits
@@ -1553,7 +1555,7 @@ class HCI_Object:
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}')
raise InvalidArgumentError(f'unknown field type {field_type}')
@staticmethod
def dict_from_bytes(data, offset, fields):
@@ -1622,7 +1624,7 @@ class HCI_Object:
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
raise InvalidArgumentError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif field_type == 'v':
@@ -1641,7 +1643,9 @@ class HCI_Object:
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(f"don't know how to serialize type {type(field_value)}")
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes
@@ -1835,6 +1839,12 @@ class Address:
data, offset, Address.PUBLIC_DEVICE_ADDRESS
)
@staticmethod
def parse_random_address(data, offset):
return Address.parse_address_with_type(
data, offset, Address.RANDOM_DEVICE_ADDRESS
)
@staticmethod
def parse_address_with_type(data, offset, address_type):
return offset + 6, Address(data[offset : offset + 6], address_type)
@@ -1905,7 +1915,7 @@ class Address:
self.address_bytes = bytes(reversed(bytes.fromhex(address)))
if len(self.address_bytes) != 6:
raise ValueError('invalid address length')
raise InvalidArgumentError('invalid address length')
self.address_type = address_type
@@ -1961,7 +1971,8 @@ class Address:
def __eq__(self, other):
return (
self.address_bytes == other.address_bytes
isinstance(other, Address)
and self.address_bytes == other.address_bytes
and self.is_public == other.is_public
)
@@ -2108,7 +2119,7 @@ class HCI_Command(HCI_Packet):
op_code, length = struct.unpack_from('<HB', packet, 1)
parameters = packet[4:]
if len(parameters) != length:
raise ValueError('invalid packet length')
raise InvalidPacketError('invalid packet length')
# Look for a registered class
cls = HCI_Command.command_classes.get(op_code)
@@ -4807,7 +4818,7 @@ class HCI_Event(HCI_Packet):
length = packet[2]
parameters = packet[3:]
if len(parameters) != length:
raise ValueError('invalid packet length')
raise InvalidPacketError('invalid packet length')
cls: Any
if event_code == HCI_LE_META_EVENT:
@@ -5174,8 +5185,8 @@ class HCI_LE_Data_Length_Change_Event(HCI_LE_Meta_Event):
),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('local_resolvable_private_address', Address.parse_address),
('peer_resolvable_private_address', Address.parse_address),
('local_resolvable_private_address', Address.parse_random_address),
('peer_resolvable_private_address', Address.parse_random_address),
('connection_interval', 2),
('peripheral_latency', 2),
('supervision_timeout', 2),
@@ -6342,7 +6353,7 @@ class HCI_AclDataPacket(HCI_Packet):
bc_flag = (h >> 14) & 3
data = packet[5:]
if len(data) != data_total_length:
raise ValueError('invalid packet length')
raise InvalidPacketError('invalid packet length')
return HCI_AclDataPacket(
connection_handle, pb_flag, bc_flag, data_total_length, data
)
@@ -6390,7 +6401,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
packet_status = (h >> 12) & 0b11
data = packet[4:]
if len(data) != data_total_length:
raise ValueError(
raise InvalidPacketError(
f'invalid packet length {len(data)} != {data_total_length}'
)
return HCI_SynchronousDataPacket(

View File

@@ -772,6 +772,8 @@ class Host(AbortableEventEmitter):
event.connection_handle,
BT_LE_TRANSPORT,
event.peer_address,
getattr(event, 'local_resolvable_private_address', None),
getattr(event, 'peer_resolvable_private_address', None),
event.role,
connection_parameters,
)
@@ -817,6 +819,8 @@ class Host(AbortableEventEmitter):
event.bd_addr,
None,
None,
None,
None,
)
else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')

View File

@@ -41,7 +41,14 @@ from typing import (
from .utils import deprecated
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .core import (
BT_CENTRAL_ROLE,
InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
OutOfResourcesError,
ProtocolError,
)
from .hci import (
HCI_LE_Connection_Update_Command,
HCI_Object,
@@ -189,17 +196,17 @@ class LeCreditBasedChannelSpec:
self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
raise InvalidArgumentError('max credits out of range')
if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
raise InvalidArgumentError('MTU out of range')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
raise InvalidArgumentError('MPS out of range')
class L2CAP_PDU:
@@ -211,7 +218,7 @@ class L2CAP_PDU:
def from_bytes(data: bytes) -> L2CAP_PDU:
# Check parameters
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
raise InvalidPacketError('not enough data for L2CAP header')
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
l2cap_pdu_payload = data[4:]
@@ -816,7 +823,7 @@ class ClassicChannel(EventEmitter):
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
raise InvalidStateError('connection already pending')
self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
@@ -1129,7 +1136,7 @@ class LeCreditBasedChannel(EventEmitter):
# Check that we can start a new connection
identifier = self.manager.next_identifier(self.connection)
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
raise InvalidStateError('too many concurrent connection requests')
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
@@ -1516,7 +1523,7 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID available')
raise OutOfResourcesError('no free CID available')
@staticmethod
def find_free_le_cid(channels: Iterable[int]) -> int:
@@ -1529,7 +1536,7 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID')
raise OutOfResourcesError('no free CID')
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
@@ -1576,15 +1583,15 @@ class ChannelManager:
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
raise ValueError('PSM already in use')
raise InvalidArgumentError('PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
raise InvalidArgumentError('invalid PSM (not odd)')
check = spec.psm >> 8
while check:
if check % 2 != 0:
raise ValueError('invalid PSM')
raise InvalidArgumentError('invalid PSM')
check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
@@ -1626,7 +1633,7 @@ class ChannelManager:
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use')
raise InvalidArgumentError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
@@ -2154,10 +2161,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
raise InvalidArgumentError('PSM cannot be None')
# Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
@@ -2206,10 +2213,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
raise InvalidArgumentError('PSM cannot be None')
# Create the channel
logger.debug(

View File

@@ -19,7 +19,12 @@ import logging
import asyncio
from functools import partial
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.core import (
BT_PERIPHERAL_ROLE,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError,
)
from bumble.colors import color
from bumble.hci import (
Address,
@@ -405,12 +410,12 @@ class RemoteLink:
def add_controller(self, controller):
if self.controller:
raise ValueError('controller already set')
raise InvalidStateError('controller already set')
self.controller = controller
def remove_controller(self, controller):
if self.controller != controller:
raise ValueError('controller mismatch')
raise InvalidStateError('controller mismatch')
self.controller = None
def get_pending_connection(self):

View File

@@ -685,10 +685,11 @@ class CodecSpecificConfiguration:
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
# TODO: Parse Metadata
metadata: bytes = b''
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
@@ -701,7 +702,8 @@ class PacRecord:
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
metadata = data[offset : offset + metadata_size]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
@@ -719,12 +721,13 @@ class PacRecord:
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(self.metadata)])
+ self.metadata
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
@@ -940,8 +943,7 @@ class AseStateMachine(gatt.Characteristic):
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
# TODO: Parse this
metadata = b''
metadata = le_audio.Metadata()
def __init__(
self,
@@ -1088,7 +1090,7 @@ class AseStateMachine(gatt.Characteristic):
AseReasonCode.NONE,
)
self.metadata = metadata
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@@ -1140,7 +1142,7 @@ class AseStateMachine(gatt.Characteristic):
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
@@ -1217,8 +1219,9 @@ class AseStateMachine(gatt.Characteristic):
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''

View File

@@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
raise core.InvalidArgumentError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
@@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
@@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
raise core.InvalidPacketError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
@@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])

View File

@@ -19,6 +19,7 @@
from enum import IntEnum
import struct
from bumble import core
from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error
from ..gatt import (
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected

View File

@@ -17,33 +17,67 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import List
import struct
from typing import List, Type
from typing_extensions import Self
from bumble import utils
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
again outside the lib.
'''
class Tag(utils.OpenIntEnum):
# fmt: off
PREFERRED_AUDIO_CONTEXTS = 0x01
STREAMING_AUDIO_CONTEXTS = 0x02
PROGRAM_INFO = 0x03
LANGUAGE = 0x04
CCID_LIST = 0x05
PARENTAL_RATING = 0x06
PROGRAM_INFO_URI = 0x07
AUDIO_ACTIVE_STATE = 0x08
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
ASSISTED_LISTENING_STREAM = 0x0A
BROADCAST_NAME = 0x0B
EXTENDED_METADATA = 0xFE
VENDOR_SPECIFIC = 0xFF
@dataclasses.dataclass
class Entry:
tag: int
tag: Metadata.Tag
data: bytes
entries: List[Entry]
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
def __bytes__(self) -> bytes:
return bytes([len(self.data) + 1, self.tag]) + self.data
entries: List[Entry] = dataclasses.field(default_factory=list)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = []
offset = 0
length = len(data)
while length >= 2:
while offset < length:
entry_length = data[offset]
entry_tag = data[offset + 1]
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
entries.append(cls.Entry(entry_tag, entry_data))
length -= entry_length
offset += 1
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
offset += entry_length
return cls(entries)
def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries])

448
bumble/profiles/mcp.py Normal file
View File

@@ -0,0 +1,448 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class PlayingOrder(utils.OpenIntEnum):
'''See Media Control Service 3.15. Playing Order.'''
SINGLE_ONCE = 0x01
SINGLE_REPEAT = 0x02
IN_ORDER_ONCE = 0x03
IN_ORDER_REPEAT = 0x04
OLDEST_ONCE = 0x05
OLDEST_REPEAT = 0x06
NEWEST_ONCE = 0x07
NEWEST_REPEAT = 0x08
SHUFFLE_ONCE = 0x09
SHUFFLE_REPEAT = 0x0A
class PlayingOrderSupported(enum.IntFlag):
'''See Media Control Service 3.16. Playing Orders Supported.'''
SINGLE_ONCE = 0x0001
SINGLE_REPEAT = 0x0002
IN_ORDER_ONCE = 0x0004
IN_ORDER_REPEAT = 0x0008
OLDEST_ONCE = 0x0010
OLDEST_REPEAT = 0x0020
NEWEST_ONCE = 0x0040
NEWEST_REPEAT = 0x0080
SHUFFLE_ONCE = 0x0100
SHUFFLE_REPEAT = 0x0200
class MediaState(utils.OpenIntEnum):
'''See Media Control Service 3.17. Media State.'''
INACTIVE = 0x00
PLAYING = 0x01
PAUSED = 0x02
SEEKING = 0x03
class MediaControlPointOpcode(utils.OpenIntEnum):
'''See Media Control Service 3.18. Media Control Point.'''
PLAY = 0x01
PAUSE = 0x02
FAST_REWIND = 0x03
FAST_FORWARD = 0x04
STOP = 0x05
MOVE_RELATIVE = 0x10
PREVIOUS_SEGMENT = 0x20
NEXT_SEGMENT = 0x21
FIRST_SEGMENT = 0x22
LAST_SEGMENT = 0x23
GOTO_SEGMENT = 0x24
PREVIOUS_TRACK = 0x30
NEXT_TRACK = 0x31
FIRST_TRACK = 0x32
LAST_TRACK = 0x33
GOTO_TRACK = 0x34
PREVIOUS_GROUP = 0x40
NEXT_GROUP = 0x41
FIRST_GROUP = 0x42
LAST_GROUP = 0x43
GOTO_GROUP = 0x44
class MediaControlPointResultCode(enum.IntFlag):
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
SUCCESS = 0x01
OPCODE_NOT_SUPPORTED = 0x02
MEDIA_PLAYER_INACTIVE = 0x03
COMMAND_CANNOT_BE_COMPLETED = 0x04
class MediaControlPointOpcodeSupported(enum.IntFlag):
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
PLAY = 0x00000001
PAUSE = 0x00000002
FAST_REWIND = 0x00000004
FAST_FORWARD = 0x00000008
STOP = 0x00000010
MOVE_RELATIVE = 0x00000020
PREVIOUS_SEGMENT = 0x00000040
NEXT_SEGMENT = 0x00000080
FIRST_SEGMENT = 0x00000100
LAST_SEGMENT = 0x00000200
GOTO_SEGMENT = 0x00000400
PREVIOUS_TRACK = 0x00000800
NEXT_TRACK = 0x00001000
FIRST_TRACK = 0x00002000
LAST_TRACK = 0x00004000
GOTO_TRACK = 0x00008000
PREVIOUS_GROUP = 0x00010000
NEXT_GROUP = 0x00020000
FIRST_GROUP = 0x00040000
LAST_GROUP = 0x00080000
GOTO_GROUP = 0x00100000
class SearchControlPointItemType(utils.OpenIntEnum):
'''See Media Control Service 3.20. Search Control Point.'''
TRACK_NAME = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
GROUP_NAME = 0x04
EARLIEST_YEAR = 0x05
LATEST_YEAR = 0x06
GENRE = 0x07
ONLY_TRACKS = 0x08
ONLY_GROUPS = 0x09
class ObjectType(utils.OpenIntEnum):
'''See Media Control Service 4.4.1. Object Type field.'''
TASK = 0
GROUP = 1
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectId(int):
'''See Media Control Service 4.4.2. Object ID field.'''
@classmethod
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(int.from_bytes(data, byteorder='little', signed=False))
def __bytes__(self) -> bytes:
return self.to_bytes(6, 'little')
@dataclasses.dataclass
class GroupObjectType:
'''See Media Control Service 4.4. Group Object Type.'''
object_type: ObjectType
object_id: ObjectId
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(
object_type=ObjectType(data[0]),
object_id=ObjectId.create_from_bytes(data[1:]),
)
def __bytes__(self) -> bytes:
return bytes([self.object_type]) + bytes(self.object_id)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class MediaControlService(gatt.TemplateService):
'''Media Control Service server implementation, only for testing currently.'''
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player',
)
self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_title_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_duration_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_position_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_state_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_control_point_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.content_control_id_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
super().__init__(
[
self.media_player_name_characteristic,
self.track_changed_characteristic,
self.track_title_characteristic,
self.track_duration_characteristic,
self.track_position_characteristic,
self.media_state_characteristic,
self.media_control_point_characteristic,
self.media_control_point_opcodes_supported_characteristic,
self.content_control_id_characteristic,
]
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(
connection,
self.media_control_point_characteristic,
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
)
class GenericMediaControlService(MediaControlService):
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class MediaControlServiceProxy(
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
):
SERVICE_CLASS = MediaControlService
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
}
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
track_changed: Optional[gatt_client.CharacteristicProxy] = None
track_title: Optional[gatt_client.CharacteristicProxy] = None
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
media_state: Optional[gatt_client.CharacteristicProxy] = None
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
None
)
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
utils.CompositeEventEmitter.__init__(self)
self.service_proxy = service_proxy
self.lock = asyncio.Lock()
self.media_control_point_notifications = asyncio.Queue()
for field, uuid in self._CHARACTERISTICS.items():
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, field, characteristics[0])
async def subscribe_characteristics(self) -> None:
if self.media_control_point:
await self.media_control_point.subscribe(self._on_media_control_point)
if self.media_state:
await self.media_state.subscribe(self._on_media_state)
if self.track_changed:
await self.track_changed.subscribe(self._on_track_changed)
if self.track_title:
await self.track_title.subscribe(self._on_track_title)
if self.track_duration:
await self.track_duration.subscribe(self._on_track_duration)
if self.track_position:
await self.track_position.subscribe(self._on_track_position)
async def write_control_point(
self, opcode: MediaControlPointOpcode
) -> MediaControlPointResultCode:
'''Writes a Media Control Point Opcode to peer and waits for the notification.
The write operation will be executed when there isn't other pending commands.
Args:
opcode: opcode defined in `MediaControlPointOpcode`.
Returns:
Response code provided in `MediaControlPointResultCode`
Raises:
InvalidOperationError: Server does not have Media Control Point Characteristic.
InvalidStateError: Server replies a notification with mismatched opcode.
'''
if not self.media_control_point:
raise core.InvalidOperationError("Peer does not have media control point")
async with self.lock:
await self.media_control_point.write_value(
bytes([opcode]),
with_response=False,
)
(
response_opcode,
response_code,
) = await self.media_control_point_notifications.get()
if response_opcode != opcode:
raise core.InvalidStateError(
f"Expected {opcode} notification, but get {response_opcode}"
)
return MediaControlPointResultCode(response_code)
def _on_media_control_point(self, data: bytes) -> None:
self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None:
del data
self.emit('track_changed')
def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
SERVICE_CLASS = GenericMediaControlService

View File

@@ -36,7 +36,9 @@ from .core import (
BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID,
InvalidArgumentError,
InvalidStateError,
InvalidPacketError,
ProtocolError,
)
@@ -335,7 +337,7 @@ class RFCOMM_Frame:
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs:
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch')
raise InvalidPacketError('fcs mismatch')
return frame
@@ -713,7 +715,7 @@ class DLC(EventEmitter):
# Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8')
else:
raise ValueError('write only accept bytes or strings')
raise InvalidArgumentError('write only accept bytes or strings')
self.tx_buffer += data
self.drained.clear()

View File

@@ -23,7 +23,7 @@ from typing_extensions import Self
from . import core, l2cap
from .colors import color
from .core import InvalidStateError
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING:
@@ -189,7 +189,9 @@ class DataElement:
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise ValueError('integer types must have a value size specified')
raise InvalidArgumentError(
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
@@ -265,7 +267,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
@@ -281,7 +283,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
@@ -354,7 +356,7 @@ class DataElement:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
@@ -365,7 +367,7 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
@@ -376,7 +378,7 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
@@ -392,7 +394,7 @@ class DataElement:
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise ValueError('NIL must be empty')
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
@@ -410,7 +412,7 @@ class DataElement:
elif size == 16:
size_index = 4
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
@@ -427,10 +429,10 @@ class DataElement:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise ValueError('boolean must be 1 byte')
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data

View File

@@ -55,6 +55,7 @@ from .core import (
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData,
InvalidArgumentError,
ProtocolError,
name_or_number,
)
@@ -766,8 +767,11 @@ class Session:
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses
self_address = connection.self_address
self_address = connection.self_resolvable_address or connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address
logger.debug(
f"pairing with self_address={self_address}, peer_address={peer_address}"
)
if self.is_initiator:
self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0
@@ -784,7 +788,7 @@ class Session:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise ValueError(
raise InvalidArgumentError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
@@ -793,7 +797,7 @@ class Session:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise ValueError(
raise InvalidArgumentError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
@@ -1075,9 +1079,9 @@ class Session:
def send_identity_address_command(self) -> None:
identity_address = {
None: self.connection.self_address,
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(

View File

@@ -23,6 +23,7 @@ import datetime
from typing import BinaryIO, Generator
import os
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
raise core.InvalidArgumentError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')

View File

@@ -20,7 +20,7 @@ import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
@@ -180,7 +180,13 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme')
if scheme == 'unix':
from .unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
raise TransportSpecError('unknown transport scheme')
# -----------------------------------------------------------------------------

View File

@@ -20,7 +20,13 @@ import grpc.aio
from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
from .common import (
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
elif ':' in param:
server_host, server_port = param.split(':')
else:
raise ValueError('invalid parameter')
raise TransportSpecError('invalid parameter')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
service = VhciForwardingServiceStub(channel)
hci_device = HciDevice(service.attachVhci())
else:
raise ValueError('invalid mode')
raise TransportSpecError('invalid mode')
# Create the transport object
class EmulatorTransport(PumpedTransport):

View File

@@ -31,6 +31,8 @@ from .common import (
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
TransportInitError,
)
# pylint: disable=no-name-in-module
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport:
if not server_port:
raise ValueError('invalid port')
raise TransportSpecError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number)
if not server_port:
raise RuntimeError('gRPC server port not found')
raise TransportInitError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
if response_type == 'error':
logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error)
raise TransportInitError(response.error)
if response_type == 'hci_packet':
return (
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
+ response.hci_packet.packet
)
raise ValueError('unsupported response type')
raise TransportSpecError('unsupported response type')
async def write(self, packet):
await self.hci_device.write(
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
options: Dict[str, str] = {}
for param in params[params_offset:]:
if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>')
raise TransportSpecError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=')
options[option_name] = option_value
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
)
if mode == 'controller':
if host is None:
raise ValueError('<host>:<port> missing')
raise TransportSpecError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option')
raise TransportSpecError('invalid mode option')

View File

@@ -23,6 +23,7 @@ import logging
import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import core
from bumble import hci
from bumble.colors import color
from bumble.snoop import Snooper
@@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(Exception):
"""
The Transport has been lost/disconnected.
"""
class TransportLostError(core.BaseBumbleError, RuntimeError):
"""The Transport has been lost/disconnected."""
class TransportInitError(core.BaseBumbleError, RuntimeError):
"""Error raised when the transport cannot be initialized."""
class TransportSpecError(core.BaseBumbleError, ValueError):
"""Error raised when the transport spec is invalid."""
# -----------------------------------------------------------------------------
@@ -132,7 +139,9 @@ class PacketParser:
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
raise core.InvalidPacketError(
f'invalid packet type {packet_type}'
)
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
@@ -178,19 +187,19 @@ class PacketReader:
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
header = self.source.read(header_size)
if len(header) != header_size:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
# Read the body
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
body = self.source.read(body_length)
if len(body) != body_length:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
return packet_type + header + body
@@ -211,7 +220,7 @@ class AsyncPacketReader:
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
@@ -420,7 +429,7 @@ class SnoopingTransport(Transport):
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
raise core.UnreachableError() # Satisfy the type checker
class Source:
sink: TransportSink

View File

@@ -29,7 +29,7 @@ from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource
from .common import Transport, ParserSource, TransportInitError
from .. import hci
from ..colors import color
@@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
device = None
if device is None:
raise ValueError('device not found')
raise TransportInitError('device not found')
logger.debug(f'USB Device: {device}')
# Power Cycle the device

56
bumble/transport/unix.py Normal file
View File

@@ -0,0 +1,56 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_unix_client_transport(spec: str) -> Transport:
'''Open a UNIX socket client transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
class UnixPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.on_transport_lost()
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
(
unix_transport,
packet_source,
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink)

View File

@@ -15,19 +15,18 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import threading
import collections
import ctypes
import platform
import usb1
from bumble.transport.common import Transport, ParserSource
from bumble.transport.common import Transport, ParserSource, TransportInitError
from bumble import hci
from bumble.colors import color
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
@@ -115,13 +114,17 @@ async def open_usb_transport(spec: str) -> Transport:
self.device = device
self.acl_out = acl_out
self.acl_out_transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.acl_out_transfer_ready = asyncio.Semaphore(1)
self.packets: asyncio.Queue[bytes] = (
asyncio.Queue()
) # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.queue_task = None
self.cancel_done = self.loop.create_future()
self.closed = False
def start(self):
pass
self.queue_task = asyncio.create_task(self.process_queue())
def on_packet(self, packet):
# Ignore packets if we're closed
@@ -133,62 +136,64 @@ async def open_usb_transport(spec: str) -> Transport:
return
# Queue the packet
self.packets.append(packet)
if len(self.packets) == 1:
# The queue was previously empty, re-prime the pump
self.process_queue()
self.packets.put_nowait(packet)
def transfer_callback(self, transfer):
self.acl_out_transfer_ready.release()
status = transfer.getStatus()
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent)
elif status == usb1.TRANSFER_CANCELLED:
if status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
return
if status != usb1.TRANSFER_COMPLETED:
logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red')
)
def on_packet_sent(self):
if self.packets:
self.packets.popleft()
self.process_queue()
async def process_queue(self):
while True:
# Wait for a packet to transfer.
packet = await self.packets.get()
def process_queue(self):
if len(self.packets) == 0:
return # Nothing to do
# Wait until we can start a transfer.
await self.acl_out_transfer_ready.acquire()
packet = self.packets[0]
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback
)
self.acl_out_transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.transfer_callback,
)
self.acl_out_transfer.submit()
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
# Transfer the packet.
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback
)
self.acl_out_transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.transfer_callback,
)
self.acl_out_transfer.submit()
else:
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
def close(self):
self.closed = True
if self.queue_task:
self.queue_task.cancel()
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data
self.packets.clear()
while not self.packets.empty():
self.packets.get_nowait()
# If we have a transfer in flight, cancel it
if self.acl_out_transfer.isSubmitted():
@@ -442,7 +447,7 @@ async def open_usb_transport(spec: str) -> Transport:
if found is None:
context.close()
raise ValueError('device not found')
raise TransportInitError('device not found')
logger.debug(f'USB Device: {found}')
@@ -507,7 +512,7 @@ async def open_usb_transport(spec: str) -> Transport:
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
raise TransportInitError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '

BIN
docs/images/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@@ -0,0 +1,7 @@
{
"name": "Bumble",
"address": "F0:F1:F2:F3:F4:F5",
"keystore": "JsonKeyStore",
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
"le_privacy_enabled": true
}

View File

@@ -3,5 +3,6 @@
"keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"class_of_device": 2376708,
"cis_enabled": true,
"advertising_interval": 100
}

83
examples/mcp_server.html Normal file
View File

@@ -0,0 +1,83 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble LEA Media Control Client</span>
</div>
</nav>
<br>
<div class="container">
<label class="form-label">Server Port</label>
<div class="input-group mb-3">
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
</div>
<button class="btn btn-primary" onclick="send_opcode(0x01)">Play</button>
<button class="btn btn-primary" onclick="send_opcode(0x02)">Pause</button>
<button class="btn btn-primary" onclick="send_opcode(0x03)">Fast Rewind</button>
<button class="btn btn-primary" onclick="send_opcode(0x04)">Fast Forward</button>
<button class="btn btn-primary" onclick="send_opcode(0x05)">Stop</button>
</br></br>
<button class="btn btn-primary" onclick="send_opcode(0x30)">Previous Track</button>
<button class="btn btn-primary" onclick="send_opcode(0x31)">Next Track</button>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let portInput = document.getElementById("port")
let log = document.getElementById("log")
let socket
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
log.textContent += 'OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = (event) => {
log.textContent += `<-- ${event.data}\n`
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
let jsonMessage = JSON.stringify(message)
log.textContent += `--> ${jsonMessage}\n`
socket.send(jsonMessage)
} else {
log.textContent += 'NOT CONNECTED\n'
}
}
function send_opcode(opcode) {
send({ 'opcode': opcode })
}
</script>
</div>
</body>
</html>

196
examples/run_mcp_client.py Normal file
View File

@@ -0,0 +1,196 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
import websockets
import json
from bumble.core import AdvertisingData
from bumble.device import (
Device,
AdvertisingParameters,
AdvertisingEventProperties,
Connection,
Peer,
)
from bumble.hci import (
CodecID,
CodingFormat,
OwnAddressType,
)
from bumble.profiles.bap import (
CodecSpecificCapabilities,
ContextType,
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
UnicastServerAdvertisingData,
)
from bumble.profiles.mcp import (
MediaControlServiceProxy,
GenericMediaControlServiceProxy,
MediaState,
MediaControlPointOpcode,
)
from bumble.transport import open_transport_or_link
from typing import Optional
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_mcp_client.py <config-file>' '<transport-spec-for-device>')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
# Add "placeholder" services to enable Android LEA features.
device.add_service(
PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED,
available_source_context=ContextType.PROHIBITED,
supported_sink_context=ContextType.MEDIA,
available_sink_context=ContextType.MEDIA,
sink_audio_locations=(
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
),
sink_pac=[
PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_16000
| SupportedSamplingFrequency.FREQ_32000
| SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=0,
max_octets_per_codec_frame=320,
supported_max_codec_frames_per_sdu=2,
),
),
],
)
)
device.add_service(AudioStreamControlService(device, sink_ase_id=[1]))
ws: Optional[websockets.WebSocketServerProtocol] = None
mcp: Optional[MediaControlServiceProxy] = None
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(UnicastServerAdvertisingData())
await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(
advertising_event_properties=AdvertisingEventProperties(),
own_address_type=OwnAddressType.RANDOM,
primary_advertising_interval_max=100,
primary_advertising_interval_min=100,
),
advertising_data=advertising_data,
auto_restart=True,
)
def on_media_state(media_state: MediaState) -> None:
if ws:
asyncio.create_task(
ws.send(json.dumps({'media_state': media_state.name}))
)
def on_track_title(title: str) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'title': title})))
def on_track_duration(duration: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'duration': duration})))
def on_track_position(position: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'position': position})))
def on_connection(connection: Connection) -> None:
async def on_connection_async():
async with Peer(connection) as peer:
nonlocal mcp
mcp = peer.create_service_proxy(MediaControlServiceProxy)
if not mcp:
mcp = peer.create_service_proxy(GenericMediaControlServiceProxy)
mcp.on('media_state', on_media_state)
mcp.on('track_title', on_track_title)
mcp.on('track_duration', on_track_duration)
mcp.on('track_position', on_track_position)
await mcp.subscribe_characteristics()
connection.abort_on('disconnection', on_connection_async())
device.on('connection', on_connection)
async def serve(websocket: websockets.WebSocketServerProtocol, _path):
nonlocal ws
ws = websocket
async for message in websocket:
request = json.loads(message)
if mcp:
await mcp.write_control_point(
MediaControlPointOpcode(request['opcode'])
)
ws = None
await websockets.serve(serve, 'localhost', 8989)
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -48,6 +48,7 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)
from bumble.profiles.le_audio import Metadata
from tests.test_utils import TwoDevices
@@ -97,7 +98,7 @@ def test_pac_record() -> None:
pac_record = PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=cap,
metadata=b'',
metadata=Metadata([Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')]),
)
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
@@ -142,7 +143,7 @@ def test_ASE_Config_QOS() -> None:
def test_ASE_Enable() -> None:
operation = ASE_Enable(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
metadata=[b'', b''],
)
basic_check(operation)
@@ -151,7 +152,7 @@ def test_ASE_Enable() -> None:
def test_ASE_Update_Metadata() -> None:
operation = ASE_Update_Metadata(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
metadata=[b'', b''],
)
basic_check(operation)

View File

@@ -276,34 +276,6 @@ async def test_legacy_advertising():
assert not device.is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
# Start advertising
await device.start_advertising()
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
await async_barrier()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
@@ -318,6 +290,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
0x0001,
BT_LE_TRANSPORT,
peer_address,
None,
None,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
@@ -367,6 +341,8 @@ async def test_extended_advertising_connection(own_address_type):
0x0001,
BT_LE_TRANSPORT,
peer_address,
None,
None,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
@@ -407,6 +383,8 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
0x0001,
BT_LE_TRANSPORT,
peer_address,
None,
None,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)

View File

@@ -879,6 +879,57 @@ async def test_unsubscribe():
mock1.assert_called_once_with(ANY, False, False)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_discover_all():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
descriptor1 = Descriptor('2902', 'READABLE,WRITEABLE')
descriptor2 = Descriptor('AAAA', 'READABLE,WRITEABLE')
characteristic2 = Characteristic(
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
descriptors=[descriptor1, descriptor2],
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic1, characteristic2],
)
service2 = Service('1111', [])
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_all()
assert len(peer.gatt_client.services) == 3
# service 1800 gets added automatically
assert peer.gatt_client.services[0].uuid == UUID('1800')
assert peer.gatt_client.services[1].uuid == service1.uuid
assert peer.gatt_client.services[2].uuid == service2.uuid
s = peer.get_services_by_uuid(service1.uuid)
assert len(s) == 1
assert len(s[0].characteristics) == 2
c = peer.get_characteristics_by_uuid(uuid=characteristic2.uuid, service=s[0])
assert len(c) == 1
assert len(c[0].descriptors) == 2
s = peer.get_services_by_uuid(service2.uuid)
assert len(s) == 1
assert len(s[0].characteristics) == 0
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
@@ -1146,6 +1197,56 @@ def test_get_attribute_group():
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_characteristics_by_uuid():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'1234',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
characteristic2 = Characteristic(
'5678',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
service1 = Service(
'ABCD',
[characteristic1, characteristic2],
)
service2 = Service(
'FFFF',
[characteristic1],
)
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0
s = peer.get_services_by_uuid(uuid=UUID('ABCD'))
assert len(s) == 1
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=s[0])
assert len(s) == 1
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

39
tests/le_audio_test.py Normal file
View File

@@ -0,0 +1,39 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.profiles import le_audio
def test_parse_metadata():
metadata = le_audio.Metadata(
entries=[
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PROGRAM_INFO,
data=b'',
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
data=bytes([0, 0]),
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
data=bytes([1, 2]),
),
]
)
assert le_audio.Metadata.from_bytes(bytes(metadata)) == metadata

132
tests/mcp_test.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import dataclasses
import pytest
import pytest_asyncio
import struct
import logging
from bumble import device
from bumble.profiles import mcp
from tests.test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
TIMEOUT = 0.1
@dataclasses.dataclass
class GmcsContext:
devices: TwoDevices
client: mcp.GenericMediaControlServiceProxy
server: mcp.GenericMediaControlService
# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def gmcs_context():
devices = TwoDevices()
server = mcp.GenericMediaControlService()
devices[0].add_service(server)
await devices.setup_connection()
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
client = await peer.discover_service_and_create_proxy(
mcp.GenericMediaControlServiceProxy
)
await client.subscribe_characteristics()
return GmcsContext(devices=devices, server=server, client=client)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_media_state(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('media_state', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.media_state_characteristic,
value=bytes([mcp.MediaState.PLAYING]),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == mcp.MediaState.PLAYING
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_title(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_title', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_title_characteristic,
value="My Song".encode(),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == "My Song"
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_duration(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_duration', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_duration_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_position(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_position', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_position_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_media_control_point(gmcs_context):
assert (
await asyncio.wait_for(
gmcs_context.client.write_control_point(mcp.MediaControlPointOpcode.PAUSE),
TIMEOUT,
)
) == mcp.MediaControlPointResultCode.SUCCESS

View File

@@ -50,4 +50,6 @@ Example:
NOTE: to get a local build of the Bumble package, use `inv build`, the built `.whl` file can be found in the `dist` directory.
Make a copy of the built `.whl` file in the `web` directory.
Make a copy of the built `.whl` file in the `web` directory.
Tip: During web developement, disable caching. [Chrome](https://stackoverflow.com/a/7000899]) / [Firefiox](https://stackoverflow.com/a/289771)

1
web/favicon.ico Symbolic link
View File

@@ -0,0 +1 @@
../docs/images/favicon.ico

View File

@@ -15,12 +15,21 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import pyee
from bumble.device import Device
from bumble.hci import HCI_Reset_Command
# -----------------------------------------------------------------------------
class Scanner:
class Scanner(pyee.EventEmitter):
"""
Scanner web app
Emitted events:
update: Emit when new `ScanEntry` are available.
"""
class ScanEntry:
def __init__(self, advertisement):
self.address = advertisement.address.to_string(False)
@@ -39,13 +48,12 @@ class Scanner:
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
self.scan_entries = {}
self.listeners = {}
self.device.on('advertisement', self.on_advertisement)
async def start(self):
print('### Starting Scanner')
self.scan_entries = {}
self.emit_update()
self.emit('update', self.scan_entries)
await self.device.power_on()
await self.device.start_scanning()
print('### Scanner started')
@@ -56,16 +64,9 @@ class Scanner:
await self.device.power_off()
print('### Scanner stopped')
def emit_update(self):
if listener := self.listeners.get('update'):
listener(list(self.scan_entries.values()))
def on(self, event_name, listener):
self.listeners[event_name] = listener
def on_advertisement(self, advertisement):
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
self.emit_update()
self.emit('update', self.scan_entries)
# -----------------------------------------------------------------------------