Compare commits

...

13 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod a0b5606047 don't user a parser for a usb source 2024-08-11 20:57:45 -07:00
Gilles Boccon-Gibod 4433184048 Merge pull request #522 from google/gbg/rpa2
add basic RPA support
2024-08-06 10:35:39 -07:00
Gilles Boccon-Gibod 312fc8db36 support controller-generated rpa 2024-08-05 08:59:05 -07:00
Gilles Boccon-Gibod 615691ec81 add basic RPA support 2024-08-01 15:37:11 -07:00
zxzxwu ae8b83f294 Merge pull request #521 from zxzxwu/bap
Add Metadata LTV serializer and adapt Unicast
2024-07-31 11:36:46 +08:00
Josh Wu 4a8e21f4db Add Metadata LTV serializer and adapt Unicast 2024-07-31 01:20:28 +08:00
zxzxwu 3462e7c437 Merge pull request #439 from zxzxwu/mcp
Media Control Service Client implementation
2024-07-24 23:45:00 +08:00
Josh Wu 0f2e5239ad MCP constants and Client implementation 2024-07-24 22:57:26 +08:00
Gilles Boccon-Gibod ee48cdc63f Merge pull request #517 from AlanRosenthal/scanner_pyee
Update scanner.py to use pyee.EventEmitter
2024-07-18 12:53:00 -07:00
Gilles Boccon-Gibod 1c278bec93 Merge pull request #518 from google/gbg/usb-queue
USB: better packet queue logic
2024-07-18 12:51:00 -07:00
Alan Rosenthal 85d79fa914 Update scanner.py to use pyee.EventEmitter 2024-07-17 16:53:50 -04:00
zxzxwu 142bdce94a Merge pull request #515 from zxzxwu/unix
Add UNIX socket transport
2024-07-17 16:04:38 +08:00
Josh Wu 881a5a64b5 Add UNIX socket transport 2024-07-17 00:41:04 +08:00
20 changed files with 1198 additions and 102 deletions
+99 -25
View File
@@ -182,6 +182,7 @@ from .core import (
BaseBumbleError, BaseBumbleError,
ConnectionParameterUpdateError, ConnectionParameterUpdateError,
CommandTimeoutError, CommandTimeoutError,
ConnectionParameters,
ConnectionPHY, ConnectionPHY,
InvalidArgumentError, InvalidArgumentError,
InvalidOperationError, InvalidOperationError,
@@ -259,8 +260,9 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
DEVICE_DEFAULT_ADVERTISING_TX_POWER = ( DEVICE_DEFAULT_ADVERTISING_TX_POWER = (
HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE HCI_LE_Set_Extended_Advertising_Parameters_Command.TX_POWER_NO_PREFERENCE
) )
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0 DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_SKIP = 0
DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0 DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT = 5.0
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -1303,6 +1305,7 @@ class Connection(CompositeEventEmitter):
handle: int handle: int
transport: int transport: int
self_address: Address self_address: Address
self_resolvable_address: Optional[Address]
peer_address: Address peer_address: Address
peer_resolvable_address: Optional[Address] peer_resolvable_address: Optional[Address]
peer_le_features: Optional[LeFeatureMask] peer_le_features: Optional[LeFeatureMask]
@@ -1350,6 +1353,7 @@ class Connection(CompositeEventEmitter):
handle, handle,
transport, transport,
self_address, self_address,
self_resolvable_address,
peer_address, peer_address,
peer_resolvable_address, peer_resolvable_address,
role, role,
@@ -1361,6 +1365,7 @@ class Connection(CompositeEventEmitter):
self.handle = handle self.handle = handle
self.transport = transport self.transport = transport
self.self_address = self_address self.self_address = self_address
self.self_resolvable_address = self_resolvable_address
self.peer_address = peer_address self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only self.peer_name = None # Classic only
@@ -1394,6 +1399,7 @@ class Connection(CompositeEventEmitter):
None, None,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
device.public_address, device.public_address,
None,
peer_address, peer_address,
None, None,
role, role,
@@ -1552,7 +1558,9 @@ class Connection(CompositeEventEmitter):
f'Connection(handle=0x{self.handle:04X}, ' f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, ' f'role={self.role_name}, '
f'self_address={self.self_address}, ' f'self_address={self.self_address}, '
f'peer_address={self.peer_address})' f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
) )
@@ -1567,8 +1575,9 @@ class DeviceConfiguration:
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
le_enabled: bool = True le_enabled: bool = True
# LE host enable 2nd parameter
le_simultaneous_enabled: bool = False le_simultaneous_enabled: bool = False
le_privacy_enabled: bool = False
le_rpa_timeout: int = DEVICE_DEFAULT_LE_RPA_TIMEOUT
classic_enabled: bool = False classic_enabled: bool = False
classic_sc_enabled: bool = True classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True classic_ssp_enabled: bool = True
@@ -1584,6 +1593,7 @@ class DeviceConfiguration:
irk: bytes = bytes(16) # This really must be changed for any level of security irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None keystore: Optional[str] = None
address_resolution_offload: bool = False address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False cis_enabled: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
@@ -1736,8 +1746,9 @@ device_host_event_handlers: List[str] = []
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Device(CompositeEventEmitter): class Device(CompositeEventEmitter):
# Incomplete list of fields. # Incomplete list of fields.
random_address: Address random_address: Address # Random address that may change with RPA
public_address: Address public_address: Address # Public address (obtained from the controller)
static_address: Address # Random address that can be set but does not change
classic_enabled: bool classic_enabled: bool
name: str name: str
class_of_device: int class_of_device: int
@@ -1867,15 +1878,19 @@ class Device(CompositeEventEmitter):
config = config or DeviceConfiguration() config = config or DeviceConfiguration()
self.config = config self.config = config
self.public_address = Address('00:00:00:00:00:00')
self.name = config.name self.name = config.name
self.public_address = Address.ANY
self.random_address = config.address self.random_address = config.address
self.static_address = config.address
self.class_of_device = config.class_of_device self.class_of_device = config.class_of_device
self.keystore = None self.keystore = None
self.irk = config.irk self.irk = config.irk
self.le_enabled = config.le_enabled self.le_enabled = config.le_enabled
self.classic_enabled = config.classic_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.le_rpa_periodic_update_task: Optional[asyncio.Task] = None
self.classic_enabled = config.classic_enabled
self.cis_enabled = config.cis_enabled self.cis_enabled = config.cis_enabled
self.classic_sc_enabled = config.classic_sc_enabled self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled self.classic_ssp_enabled = config.classic_ssp_enabled
@@ -1884,6 +1899,7 @@ class Device(CompositeEventEmitter):
self.connectable = config.connectable self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload self.address_resolution_offload = config.address_resolution_offload
self.address_generation_offload = config.address_generation_offload
# Extended advertising. # Extended advertising.
self.extended_advertising_sets: Dict[int, AdvertisingSet] = {} self.extended_advertising_sets: Dict[int, AdvertisingSet] = {}
@@ -1939,6 +1955,7 @@ class Device(CompositeEventEmitter):
if isinstance(address, str): if isinstance(address, str):
address = Address(address) address = Address(address)
self.random_address = address self.random_address = address
self.static_address = address
# Setup SMP # Setup SMP
self.smp_manager = smp.Manager( self.smp_manager = smp.Manager(
@@ -2170,6 +2187,16 @@ class Device(CompositeEventEmitter):
) )
if self.le_enabled: if self.le_enabled:
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
logger.info(f'Initial RPA: {self.random_address}')
if self.le_rpa_timeout > 0:
# Start a task to periodically generate a new RPA
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
# Set the controller address # Set the controller address
if self.random_address == Address.ANY_RANDOM: if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller # Try to use an address generated at random by the controller
@@ -2249,9 +2276,45 @@ class Device(CompositeEventEmitter):
async def power_off(self) -> None: async def power_off(self) -> None:
if self.powered_on: if self.powered_on:
if self.le_rpa_periodic_update_task:
self.le_rpa_periodic_update_task.cancel()
await self.host.flush() await self.host.flush()
self.powered_on = False self.powered_on = False
async def update_rpa(self) -> bool:
"""
Try to update the RPA.
Returns:
True if the RPA was updated, False if it could not be updated.
"""
# Check if this is a good time to rotate the address
if self.is_advertising or self.is_scanning or self.is_le_connecting:
logger.debug('skipping RPA update')
return False
random_address = Address.generate_private_address(self.irk)
response = await self.send_command(
HCI_LE_Set_Random_Address_Command(random_address=self.random_address)
)
if response.return_parameters == HCI_SUCCESS:
logger.info(f'new RPA: {random_address}')
self.random_address = random_address
return True
else:
logger.warning(f'failed to set RPA: {response.return_parameters}')
return False
async def _run_rpa_periodic_update(self) -> None:
"""Update the RPA periodically"""
while self.le_rpa_timeout != 0:
await asyncio.sleep(self.le_rpa_timeout)
if not self.update_rpa():
logger.debug("periodic RPA update failed")
async def refresh_resolving_list(self) -> None: async def refresh_resolving_list(self) -> None:
assert self.keystore is not None assert self.keystore is not None
@@ -2259,7 +2322,7 @@ class Device(CompositeEventEmitter):
# Create a host-side address resolver # Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys) self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload: if self.address_resolution_offload or self.address_generation_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) await self.send_command(HCI_LE_Clear_Resolving_List_Command())
# Add an empty entry for non-directed address generation. # Add an empty entry for non-directed address generation.
@@ -4104,12 +4167,14 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
def on_connection( def on_connection(
self, self,
connection_handle, connection_handle: int,
transport, transport: int,
peer_address, peer_address: Address,
role, self_resolvable_address: Optional[Address],
connection_parameters, peer_resolvable_address: Optional[Address],
): role: int,
connection_parameters: ConnectionParameters,
) -> None:
logger.debug( logger.debug(
f'*** Connection: [0x{connection_handle:04X}] ' f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}' f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
@@ -4130,15 +4195,15 @@ class Device(CompositeEventEmitter):
return return
# Resolve the peer address if we can if peer_resolvable_address is None:
peer_resolvable_address = None # Resolve the peer address if we can
if self.address_resolver: if self.address_resolver:
if peer_address.is_resolvable: if peer_address.is_resolvable:
resolved_address = self.address_resolver.resolve(peer_address) resolved_address = self.address_resolver.resolve(peer_address)
if resolved_address is not None: if resolved_address is not None:
logger.debug(f'*** Address resolved as {resolved_address}') logger.debug(f'*** Address resolved as {resolved_address}')
peer_resolvable_address = peer_address peer_resolvable_address = peer_address
peer_address = resolved_address peer_address = resolved_address
self_address = None self_address = None
if role == HCI_CENTRAL_ROLE: if role == HCI_CENTRAL_ROLE:
@@ -4169,12 +4234,19 @@ class Device(CompositeEventEmitter):
else self.random_address else self.random_address
) )
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if peer_resolvable_address == Address.ANY_RANDOM:
peer_resolvable_address = None
# Create a connection. # Create a connection.
connection = Connection( connection = Connection(
self, self,
connection_handle, connection_handle,
transport, transport,
self_address, self_address,
self_resolvable_address,
peer_address, peer_address,
peer_resolvable_address, peer_resolvable_address,
role, role,
@@ -4185,9 +4257,10 @@ class Device(CompositeEventEmitter):
if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser: if role == HCI_PERIPHERAL_ROLE and self.legacy_advertiser:
if self.legacy_advertiser.auto_restart: if self.legacy_advertiser.auto_restart:
advertiser = self.legacy_advertiser
connection.once( connection.once(
'disconnection', 'disconnection',
lambda _: self.abort_on('flush', self.legacy_advertiser.start()), lambda _: self.abort_on('flush', advertiser.start()),
) )
else: else:
self.legacy_advertiser = None self.legacy_advertiser = None
@@ -4871,5 +4944,6 @@ class Device(CompositeEventEmitter):
return ( return (
f'Device(name="{self.name}", ' f'Device(name="{self.name}", '
f'random_address="{self.random_address}", ' f'random_address="{self.random_address}", '
f'public_address="{self.public_address}")' f'public_address="{self.public_address}", '
f'static_address="{self.static_address}")'
) )
+10 -3
View File
@@ -1839,6 +1839,12 @@ class Address:
data, offset, Address.PUBLIC_DEVICE_ADDRESS data, offset, Address.PUBLIC_DEVICE_ADDRESS
) )
@staticmethod
def parse_random_address(data, offset):
return Address.parse_address_with_type(
data, offset, Address.RANDOM_DEVICE_ADDRESS
)
@staticmethod @staticmethod
def parse_address_with_type(data, offset, address_type): def parse_address_with_type(data, offset, address_type):
return offset + 6, Address(data[offset : offset + 6], address_type) return offset + 6, Address(data[offset : offset + 6], address_type)
@@ -1965,7 +1971,8 @@ class Address:
def __eq__(self, other): def __eq__(self, other):
return ( return (
self.address_bytes == other.address_bytes isinstance(other, Address)
and self.address_bytes == other.address_bytes
and self.is_public == other.is_public and self.is_public == other.is_public
) )
@@ -5178,8 +5185,8 @@ class HCI_LE_Data_Length_Change_Event(HCI_LE_Meta_Event):
), ),
('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type), ('peer_address', Address.parse_address_preceded_by_type),
('local_resolvable_private_address', Address.parse_address), ('local_resolvable_private_address', Address.parse_random_address),
('peer_resolvable_private_address', Address.parse_address), ('peer_resolvable_private_address', Address.parse_random_address),
('connection_interval', 2), ('connection_interval', 2),
('peripheral_latency', 2), ('peripheral_latency', 2),
('supervision_timeout', 2), ('supervision_timeout', 2),
+4
View File
@@ -772,6 +772,8 @@ class Host(AbortableEventEmitter):
event.connection_handle, event.connection_handle,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
event.peer_address, event.peer_address,
getattr(event, 'local_resolvable_private_address', None),
getattr(event, 'peer_resolvable_private_address', None),
event.role, event.role,
connection_parameters, connection_parameters,
) )
@@ -817,6 +819,8 @@ class Host(AbortableEventEmitter):
event.bd_addr, event.bd_addr,
None, None,
None, None,
None,
None,
) )
else: else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
+13 -10
View File
@@ -685,10 +685,11 @@ class CodecSpecificConfiguration:
@dataclasses.dataclass @dataclasses.dataclass
class PacRecord: class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
# TODO: Parse Metadata metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
metadata: bytes = b''
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> PacRecord: def from_bytes(cls, data: bytes) -> PacRecord:
@@ -701,7 +702,8 @@ class PacRecord:
] ]
offset += codec_specific_capabilities_size offset += codec_specific_capabilities_size
metadata_size = data[offset] metadata_size = data[offset]
metadata = data[offset : offset + metadata_size] offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
@@ -719,12 +721,13 @@ class PacRecord:
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities) capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return ( return (
bytes(self.coding_format) bytes(self.coding_format)
+ bytes([len(capabilities_bytes)]) + bytes([len(capabilities_bytes)])
+ capabilities_bytes + capabilities_bytes
+ bytes([len(self.metadata)]) + bytes([len(metadata_bytes)])
+ self.metadata + metadata_bytes
) )
@@ -940,8 +943,7 @@ class AseStateMachine(gatt.Characteristic):
presentation_delay = 0 presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State # Additional parameters in ENABLING, STREAMING, DISABLING State
# TODO: Parse this metadata = le_audio.Metadata()
metadata = b''
def __init__( def __init__(
self, self,
@@ -1088,7 +1090,7 @@ class AseStateMachine(gatt.Characteristic):
AseReasonCode.NONE, AseReasonCode.NONE,
) )
self.metadata = metadata self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@@ -1140,7 +1142,7 @@ class AseStateMachine(gatt.Characteristic):
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE, AseReasonCode.NONE,
) )
self.metadata = metadata self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
@@ -1217,8 +1219,9 @@ class AseStateMachine(gatt.Characteristic):
self.State.STREAMING, self.State.STREAMING,
self.State.DISABLING, self.State.DISABLING,
): ):
metadata_bytes = bytes(self.metadata)
additional_parameters = ( additional_parameters = (
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
) )
else: else:
additional_parameters = b'' additional_parameters = b''
+43 -9
View File
@@ -17,33 +17,67 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
from typing import List import struct
from typing import List, Type
from typing_extensions import Self from typing_extensions import Self
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class Metadata: class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
again outside the lib.
'''
class Tag(utils.OpenIntEnum):
# fmt: off
PREFERRED_AUDIO_CONTEXTS = 0x01
STREAMING_AUDIO_CONTEXTS = 0x02
PROGRAM_INFO = 0x03
LANGUAGE = 0x04
CCID_LIST = 0x05
PARENTAL_RATING = 0x06
PROGRAM_INFO_URI = 0x07
AUDIO_ACTIVE_STATE = 0x08
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
ASSISTED_LISTENING_STREAM = 0x0A
BROADCAST_NAME = 0x0B
EXTENDED_METADATA = 0xFE
VENDOR_SPECIFIC = 0xFF
@dataclasses.dataclass @dataclasses.dataclass
class Entry: class Entry:
tag: int tag: Metadata.Tag
data: bytes data: bytes
entries: List[Entry] @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
def __bytes__(self) -> bytes:
return bytes([len(self.data) + 1, self.tag]) + self.data
entries: List[Entry] = dataclasses.field(default_factory=list)
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = [] entries = []
offset = 0 offset = 0
length = len(data) length = len(data)
while length >= 2: while offset < length:
entry_length = data[offset] entry_length = data[offset]
entry_tag = data[offset + 1] offset += 1
entry_data = data[offset + 2 : offset + 2 + entry_length - 1] entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
entries.append(cls.Entry(entry_tag, entry_data))
length -= entry_length
offset += entry_length offset += entry_length
return cls(entries) return cls(entries)
def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries])
+448
View File
@@ -0,0 +1,448 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class PlayingOrder(utils.OpenIntEnum):
'''See Media Control Service 3.15. Playing Order.'''
SINGLE_ONCE = 0x01
SINGLE_REPEAT = 0x02
IN_ORDER_ONCE = 0x03
IN_ORDER_REPEAT = 0x04
OLDEST_ONCE = 0x05
OLDEST_REPEAT = 0x06
NEWEST_ONCE = 0x07
NEWEST_REPEAT = 0x08
SHUFFLE_ONCE = 0x09
SHUFFLE_REPEAT = 0x0A
class PlayingOrderSupported(enum.IntFlag):
'''See Media Control Service 3.16. Playing Orders Supported.'''
SINGLE_ONCE = 0x0001
SINGLE_REPEAT = 0x0002
IN_ORDER_ONCE = 0x0004
IN_ORDER_REPEAT = 0x0008
OLDEST_ONCE = 0x0010
OLDEST_REPEAT = 0x0020
NEWEST_ONCE = 0x0040
NEWEST_REPEAT = 0x0080
SHUFFLE_ONCE = 0x0100
SHUFFLE_REPEAT = 0x0200
class MediaState(utils.OpenIntEnum):
'''See Media Control Service 3.17. Media State.'''
INACTIVE = 0x00
PLAYING = 0x01
PAUSED = 0x02
SEEKING = 0x03
class MediaControlPointOpcode(utils.OpenIntEnum):
'''See Media Control Service 3.18. Media Control Point.'''
PLAY = 0x01
PAUSE = 0x02
FAST_REWIND = 0x03
FAST_FORWARD = 0x04
STOP = 0x05
MOVE_RELATIVE = 0x10
PREVIOUS_SEGMENT = 0x20
NEXT_SEGMENT = 0x21
FIRST_SEGMENT = 0x22
LAST_SEGMENT = 0x23
GOTO_SEGMENT = 0x24
PREVIOUS_TRACK = 0x30
NEXT_TRACK = 0x31
FIRST_TRACK = 0x32
LAST_TRACK = 0x33
GOTO_TRACK = 0x34
PREVIOUS_GROUP = 0x40
NEXT_GROUP = 0x41
FIRST_GROUP = 0x42
LAST_GROUP = 0x43
GOTO_GROUP = 0x44
class MediaControlPointResultCode(enum.IntFlag):
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
SUCCESS = 0x01
OPCODE_NOT_SUPPORTED = 0x02
MEDIA_PLAYER_INACTIVE = 0x03
COMMAND_CANNOT_BE_COMPLETED = 0x04
class MediaControlPointOpcodeSupported(enum.IntFlag):
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
PLAY = 0x00000001
PAUSE = 0x00000002
FAST_REWIND = 0x00000004
FAST_FORWARD = 0x00000008
STOP = 0x00000010
MOVE_RELATIVE = 0x00000020
PREVIOUS_SEGMENT = 0x00000040
NEXT_SEGMENT = 0x00000080
FIRST_SEGMENT = 0x00000100
LAST_SEGMENT = 0x00000200
GOTO_SEGMENT = 0x00000400
PREVIOUS_TRACK = 0x00000800
NEXT_TRACK = 0x00001000
FIRST_TRACK = 0x00002000
LAST_TRACK = 0x00004000
GOTO_TRACK = 0x00008000
PREVIOUS_GROUP = 0x00010000
NEXT_GROUP = 0x00020000
FIRST_GROUP = 0x00040000
LAST_GROUP = 0x00080000
GOTO_GROUP = 0x00100000
class SearchControlPointItemType(utils.OpenIntEnum):
'''See Media Control Service 3.20. Search Control Point.'''
TRACK_NAME = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
GROUP_NAME = 0x04
EARLIEST_YEAR = 0x05
LATEST_YEAR = 0x06
GENRE = 0x07
ONLY_TRACKS = 0x08
ONLY_GROUPS = 0x09
class ObjectType(utils.OpenIntEnum):
'''See Media Control Service 4.4.1. Object Type field.'''
TASK = 0
GROUP = 1
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectId(int):
'''See Media Control Service 4.4.2. Object ID field.'''
@classmethod
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(int.from_bytes(data, byteorder='little', signed=False))
def __bytes__(self) -> bytes:
return self.to_bytes(6, 'little')
@dataclasses.dataclass
class GroupObjectType:
'''See Media Control Service 4.4. Group Object Type.'''
object_type: ObjectType
object_id: ObjectId
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(
object_type=ObjectType(data[0]),
object_id=ObjectId.create_from_bytes(data[1:]),
)
def __bytes__(self) -> bytes:
return bytes([self.object_type]) + bytes(self.object_id)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class MediaControlService(gatt.TemplateService):
'''Media Control Service server implementation, only for testing currently.'''
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player',
)
self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_title_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_duration_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_position_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_state_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_control_point_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.content_control_id_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
super().__init__(
[
self.media_player_name_characteristic,
self.track_changed_characteristic,
self.track_title_characteristic,
self.track_duration_characteristic,
self.track_position_characteristic,
self.media_state_characteristic,
self.media_control_point_characteristic,
self.media_control_point_opcodes_supported_characteristic,
self.content_control_id_characteristic,
]
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(
connection,
self.media_control_point_characteristic,
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
)
class GenericMediaControlService(MediaControlService):
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class MediaControlServiceProxy(
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
):
SERVICE_CLASS = MediaControlService
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
}
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
track_changed: Optional[gatt_client.CharacteristicProxy] = None
track_title: Optional[gatt_client.CharacteristicProxy] = None
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
media_state: Optional[gatt_client.CharacteristicProxy] = None
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
None
)
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
utils.CompositeEventEmitter.__init__(self)
self.service_proxy = service_proxy
self.lock = asyncio.Lock()
self.media_control_point_notifications = asyncio.Queue()
for field, uuid in self._CHARACTERISTICS.items():
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, field, characteristics[0])
async def subscribe_characteristics(self) -> None:
if self.media_control_point:
await self.media_control_point.subscribe(self._on_media_control_point)
if self.media_state:
await self.media_state.subscribe(self._on_media_state)
if self.track_changed:
await self.track_changed.subscribe(self._on_track_changed)
if self.track_title:
await self.track_title.subscribe(self._on_track_title)
if self.track_duration:
await self.track_duration.subscribe(self._on_track_duration)
if self.track_position:
await self.track_position.subscribe(self._on_track_position)
async def write_control_point(
self, opcode: MediaControlPointOpcode
) -> MediaControlPointResultCode:
'''Writes a Media Control Point Opcode to peer and waits for the notification.
The write operation will be executed when there isn't other pending commands.
Args:
opcode: opcode defined in `MediaControlPointOpcode`.
Returns:
Response code provided in `MediaControlPointResultCode`
Raises:
InvalidOperationError: Server does not have Media Control Point Characteristic.
InvalidStateError: Server replies a notification with mismatched opcode.
'''
if not self.media_control_point:
raise core.InvalidOperationError("Peer does not have media control point")
async with self.lock:
await self.media_control_point.write_value(
bytes([opcode]),
with_response=False,
)
(
response_opcode,
response_code,
) = await self.media_control_point_notifications.get()
if response_opcode != opcode:
raise core.InvalidStateError(
f"Expected {opcode} notification, but get {response_opcode}"
)
return MediaControlPointResultCode(response_code)
def _on_media_control_point(self, data: bytes) -> None:
self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None:
del data
self.emit('track_changed')
def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
SERVICE_CLASS = GenericMediaControlService
+6 -3
View File
@@ -767,8 +767,11 @@ class Session:
self.oob_data_flag = 0 if pairing_config.oob is None else 1 self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses # Set up addresses
self_address = connection.self_address self_address = connection.self_resolvable_address or connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address peer_address = connection.peer_resolvable_address or connection.peer_address
logger.debug(
f"pairing with self_address={self_address}, peer_address={peer_address}"
)
if self.is_initiator: if self.is_initiator:
self.ia = bytes(self_address) self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0 self.iat = 1 if self_address.is_random else 0
@@ -1076,9 +1079,9 @@ class Session:
def send_identity_address_command(self) -> None: def send_identity_address_command(self) -> None:
identity_address = { identity_address = {
None: self.connection.self_address, None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address, Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address, Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type] }[self.pairing_config.identity_address_type]
self.send_command( self.send_command(
SMP_Identity_Address_Information_Command( SMP_Identity_Address_Information_Command(
+6
View File
@@ -180,6 +180,12 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec)
if scheme == 'unix':
from .unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
raise TransportSpecError('unknown transport scheme') raise TransportSpecError('unknown transport scheme')
+24 -7
View File
@@ -248,26 +248,26 @@ class AsyncPipeSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ParserSource: class BaseSource:
""" """
Base class designed to be subclassed by transport-specific source classes Base class designed to be subclassed by transport-specific source classes
""" """
terminated: asyncio.Future[None] terminated: asyncio.Future[None]
parser: PacketParser sink: Optional[TransportSink]
def __init__(self) -> None: def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
self.sink = None
def set_packet_sink(self, sink: TransportSink) -> None: def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink) self.sink = sink
def on_transport_lost(self) -> None: def on_transport_lost(self) -> None:
self.terminated.set_result(None) self.terminated.set_result(None)
if self.parser.sink: if self.sink:
if hasattr(self.parser.sink, 'on_transport_lost'): if hasattr(self.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost() self.sink.on_transport_lost()
async def wait_for_termination(self) -> None: async def wait_for_termination(self) -> None:
""" """
@@ -280,6 +280,23 @@ class ParserSource:
pass pass
# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
"""
Base class for sources that use an HCI parser.
"""
parser: PacketParser
def __init__(self) -> None:
super().__init__()
self.parser = PacketParser()
def set_packet_sink(self, sink: TransportSink) -> None:
super().set_packet_sink(sink)
self.parser.set_packet_sink(sink)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
+56
View File
@@ -0,0 +1,56 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_unix_client_transport(spec: str) -> Transport:
'''Open a UNIX socket client transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
class UnixPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.on_transport_lost()
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
(
unix_transport,
packet_source,
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink)
+9 -3
View File
@@ -24,7 +24,7 @@ import platform
import usb1 import usb1
from bumble.transport.common import Transport, ParserSource, TransportInitError from bumble.transport.common import Transport, BaseSource, TransportInitError
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
@@ -208,7 +208,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError: except usb1.USBError:
logger.debug('OUT transfer likely already completed') logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, BaseSource):
def __init__(self, device, metadata, acl_in, events_in): def __init__(self, device, metadata, acl_in, events_in):
super().__init__() super().__init__()
self.device = device self.device = device
@@ -285,7 +285,13 @@ async def open_usb_transport(spec: str) -> Transport:
packet = await self.queue.get() packet = await self.queue.get()
except asyncio.CancelledError: except asyncio.CancelledError:
return return
self.parser.feed_data(packet) if self.sink:
try:
self.sink.on_packet(packet)
except Exception as error:
logger.exception(
color(f'!!! Exception in sink.on_packet: {error}', 'red')
)
def close(self): def close(self):
self.closed = True self.closed = True
+7
View File
@@ -0,0 +1,7 @@
{
"name": "Bumble",
"address": "F0:F1:F2:F3:F4:F5",
"keystore": "JsonKeyStore",
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
"le_privacy_enabled": true
}
+1
View File
@@ -3,5 +3,6 @@
"keystore": "JsonKeyStore", "keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA", "address": "F0:F1:F2:F3:F4:FA",
"class_of_device": 2376708, "class_of_device": 2376708,
"cis_enabled": true,
"advertising_interval": 100 "advertising_interval": 100
} }
+83
View File
@@ -0,0 +1,83 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble LEA Media Control Client</span>
</div>
</nav>
<br>
<div class="container">
<label class="form-label">Server Port</label>
<div class="input-group mb-3">
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
</div>
<button class="btn btn-primary" onclick="send_opcode(0x01)">Play</button>
<button class="btn btn-primary" onclick="send_opcode(0x02)">Pause</button>
<button class="btn btn-primary" onclick="send_opcode(0x03)">Fast Rewind</button>
<button class="btn btn-primary" onclick="send_opcode(0x04)">Fast Forward</button>
<button class="btn btn-primary" onclick="send_opcode(0x05)">Stop</button>
</br></br>
<button class="btn btn-primary" onclick="send_opcode(0x30)">Previous Track</button>
<button class="btn btn-primary" onclick="send_opcode(0x31)">Next Track</button>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let portInput = document.getElementById("port")
let log = document.getElementById("log")
let socket
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
log.textContent += 'OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = (event) => {
log.textContent += `<-- ${event.data}\n`
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
let jsonMessage = JSON.stringify(message)
log.textContent += `--> ${jsonMessage}\n`
socket.send(jsonMessage)
} else {
log.textContent += 'NOT CONNECTED\n'
}
}
function send_opcode(opcode) {
send({ 'opcode': opcode })
}
</script>
</div>
</body>
</html>
+196
View File
@@ -0,0 +1,196 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
import websockets
import json
from bumble.core import AdvertisingData
from bumble.device import (
Device,
AdvertisingParameters,
AdvertisingEventProperties,
Connection,
Peer,
)
from bumble.hci import (
CodecID,
CodingFormat,
OwnAddressType,
)
from bumble.profiles.bap import (
CodecSpecificCapabilities,
ContextType,
AudioLocation,
SupportedSamplingFrequency,
SupportedFrameDuration,
PacRecord,
PublishedAudioCapabilitiesService,
AudioStreamControlService,
UnicastServerAdvertisingData,
)
from bumble.profiles.mcp import (
MediaControlServiceProxy,
GenericMediaControlServiceProxy,
MediaState,
MediaControlPointOpcode,
)
from bumble.transport import open_transport_or_link
from typing import Optional
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print('Usage: run_mcp_client.py <config-file>' '<transport-spec-for-device>')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
# Add "placeholder" services to enable Android LEA features.
device.add_service(
PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED,
available_source_context=ContextType.PROHIBITED,
supported_sink_context=ContextType.MEDIA,
available_sink_context=ContextType.MEDIA,
sink_audio_locations=(
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
),
sink_pac=[
PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=CodecSpecificCapabilities(
supported_sampling_frequencies=(
SupportedSamplingFrequency.FREQ_16000
| SupportedSamplingFrequency.FREQ_32000
| SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=0,
max_octets_per_codec_frame=320,
supported_max_codec_frames_per_sdu=2,
),
),
],
)
)
device.add_service(AudioStreamControlService(device, sink_ase_id=[1]))
ws: Optional[websockets.WebSocketServerProtocol] = None
mcp: Optional[MediaControlServiceProxy] = None
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble LE Audio', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(UnicastServerAdvertisingData())
await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(
advertising_event_properties=AdvertisingEventProperties(),
own_address_type=OwnAddressType.RANDOM,
primary_advertising_interval_max=100,
primary_advertising_interval_min=100,
),
advertising_data=advertising_data,
auto_restart=True,
)
def on_media_state(media_state: MediaState) -> None:
if ws:
asyncio.create_task(
ws.send(json.dumps({'media_state': media_state.name}))
)
def on_track_title(title: str) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'title': title})))
def on_track_duration(duration: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'duration': duration})))
def on_track_position(position: int) -> None:
if ws:
asyncio.create_task(ws.send(json.dumps({'position': position})))
def on_connection(connection: Connection) -> None:
async def on_connection_async():
async with Peer(connection) as peer:
nonlocal mcp
mcp = peer.create_service_proxy(MediaControlServiceProxy)
if not mcp:
mcp = peer.create_service_proxy(GenericMediaControlServiceProxy)
mcp.on('media_state', on_media_state)
mcp.on('track_title', on_track_title)
mcp.on('track_duration', on_track_duration)
mcp.on('track_position', on_track_position)
await mcp.subscribe_characteristics()
connection.abort_on('disconnection', on_connection_async())
device.on('connection', on_connection)
async def serve(websocket: websockets.WebSocketServerProtocol, _path):
nonlocal ws
ws = websocket
async for message in websocket:
request = json.loads(message)
if mcp:
await mcp.write_control_point(
MediaControlPointOpcode(request['opcode'])
)
ws = None
await websockets.serve(serve, 'localhost', 8989)
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+4 -3
View File
@@ -48,6 +48,7 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy, PublishedAudioCapabilitiesServiceProxy,
) )
from bumble.profiles.le_audio import Metadata
from tests.test_utils import TwoDevices from tests.test_utils import TwoDevices
@@ -97,7 +98,7 @@ def test_pac_record() -> None:
pac_record = PacRecord( pac_record = PacRecord(
coding_format=CodingFormat(CodecID.LC3), coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=cap, codec_specific_capabilities=cap,
metadata=b'', metadata=Metadata([Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')]),
) )
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
@@ -142,7 +143,7 @@ def test_ASE_Config_QOS() -> None:
def test_ASE_Enable() -> None: def test_ASE_Enable() -> None:
operation = ASE_Enable( operation = ASE_Enable(
ase_id=[1, 2], ase_id=[1, 2],
metadata=[b'foo', b'bar'], metadata=[b'', b''],
) )
basic_check(operation) basic_check(operation)
@@ -151,7 +152,7 @@ def test_ASE_Enable() -> None:
def test_ASE_Update_Metadata() -> None: def test_ASE_Update_Metadata() -> None:
operation = ASE_Update_Metadata( operation = ASE_Update_Metadata(
ase_id=[1, 2], ase_id=[1, 2],
metadata=[b'foo', b'bar'], metadata=[b'', b''],
) )
basic_check(operation) basic_check(operation)
+6 -28
View File
@@ -276,34 +276,6 @@ async def test_legacy_advertising():
assert not device.is_advertising assert not device.is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
# Start advertising
await device.start_advertising()
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
await async_barrier()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.parametrize( @pytest.mark.parametrize(
'auto_restart,', 'auto_restart,',
@@ -318,6 +290,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
@@ -367,6 +341,8 @@ async def test_extended_advertising_connection(own_address_type):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
@@ -407,6 +383,8 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, peer_address,
None,
None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0), ConnectionParameters(0, 0, 0),
) )
+39
View File
@@ -0,0 +1,39 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.profiles import le_audio
def test_parse_metadata():
metadata = le_audio.Metadata(
entries=[
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PROGRAM_INFO,
data=b'',
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
data=bytes([0, 0]),
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
data=bytes([1, 2]),
),
]
)
assert le_audio.Metadata.from_bytes(bytes(metadata)) == metadata
+132
View File
@@ -0,0 +1,132 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import dataclasses
import pytest
import pytest_asyncio
import struct
import logging
from bumble import device
from bumble.profiles import mcp
from tests.test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
TIMEOUT = 0.1
@dataclasses.dataclass
class GmcsContext:
devices: TwoDevices
client: mcp.GenericMediaControlServiceProxy
server: mcp.GenericMediaControlService
# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def gmcs_context():
devices = TwoDevices()
server = mcp.GenericMediaControlService()
devices[0].add_service(server)
await devices.setup_connection()
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
client = await peer.discover_service_and_create_proxy(
mcp.GenericMediaControlServiceProxy
)
await client.subscribe_characteristics()
return GmcsContext(devices=devices, server=server, client=client)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_media_state(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('media_state', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.media_state_characteristic,
value=bytes([mcp.MediaState.PLAYING]),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == mcp.MediaState.PLAYING
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_title(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_title', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_title_characteristic,
value="My Song".encode(),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == "My Song"
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_duration(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_duration', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_duration_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_update_track_position(gmcs_context):
state = asyncio.Queue()
gmcs_context.client.on('track_position', state.put_nowait)
await gmcs_context.devices[0].notify_subscribers(
gmcs_context.server.track_position_characteristic,
value=struct.pack("<i", 1000),
)
assert (await asyncio.wait_for(state.get(), TIMEOUT)) == 1000
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_media_control_point(gmcs_context):
assert (
await asyncio.wait_for(
gmcs_context.client.write_control_point(mcp.MediaControlPointOpcode.PAUSE),
TIMEOUT,
)
) == mcp.MediaControlPointResultCode.SUCCESS
+12 -11
View File
@@ -15,12 +15,21 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import pyee
from bumble.device import Device from bumble.device import Device
from bumble.hci import HCI_Reset_Command from bumble.hci import HCI_Reset_Command
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Scanner: class Scanner(pyee.EventEmitter):
"""
Scanner web app
Emitted events:
update: Emit when new `ScanEntry` are available.
"""
class ScanEntry: class ScanEntry:
def __init__(self, advertisement): def __init__(self, advertisement):
self.address = advertisement.address.to_string(False) self.address = advertisement.address.to_string(False)
@@ -39,13 +48,12 @@ class Scanner:
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
) )
self.scan_entries = {} self.scan_entries = {}
self.listeners = {}
self.device.on('advertisement', self.on_advertisement) self.device.on('advertisement', self.on_advertisement)
async def start(self): async def start(self):
print('### Starting Scanner') print('### Starting Scanner')
self.scan_entries = {} self.scan_entries = {}
self.emit_update() self.emit('update', self.scan_entries)
await self.device.power_on() await self.device.power_on()
await self.device.start_scanning() await self.device.start_scanning()
print('### Scanner started') print('### Scanner started')
@@ -56,16 +64,9 @@ class Scanner:
await self.device.power_off() await self.device.power_off()
print('### Scanner stopped') print('### Scanner stopped')
def emit_update(self):
if listener := self.listeners.get('update'):
listener(list(self.scan_entries.values()))
def on(self, event_name, listener):
self.listeners[event_name] = listener
def on_advertisement(self, advertisement): def on_advertisement(self, advertisement):
self.scan_entries[advertisement.address] = self.ScanEntry(advertisement) self.scan_entries[advertisement.address] = self.ScanEntry(advertisement)
self.emit_update() self.emit('update', self.scan_entries)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------