Compare commits

...

28 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
2478d45673 more windows compat fixes 2023-09-12 14:52:42 -07:00
Gilles Boccon-Gibod
1bc7d94111 windows NamedTemporaryFile compatibility 2023-09-12 14:33:12 -07:00
Gilles Boccon-Gibod
6432414cd5 run tests on windows and mac in addition to linux 2023-09-12 13:50:15 -07:00
Gilles Boccon-Gibod
179064ba15 run pre-commit tests with all supported Python versions 2023-09-12 13:42:33 -07:00
William Escande
783b2d70a5 Add connection parameter update from peripheral 2023-09-12 11:08:04 -07:00
zxzxwu
80824f3fc1 Merge pull request #280 from zxzxwu/device_typing
Add terminated to TransportSource protocol
2023-09-12 20:46:35 +08:00
Gilles Boccon-Gibod
56139c622f Merge pull request #258 from mogenson/vsc_tx_power
Add support for Zephyr HCI VSC set TX power command
2023-09-11 21:34:11 -07:00
Michael Mogenson
da02f6a39b Add HCI Zephyr vendor commands to read and write TX power
Create platforms/zephyr/hci.py with definitions of vendor HCI commands
to read and write TX power.

Add documentation for how to prepare an nRF52840 dongle with a Zephyr
HCI USB firmware application that includes dynamic TX power support and
how to send a write TX power vendor HCI command from Bumble.
2023-09-11 10:06:10 -04:00
Josh Wu
548d5597c0 Transport: Add termination protocol signature 2023-09-11 14:36:40 +08:00
zxzxwu
7fd65d2412 Merge pull request #279 from zxzxwu/typo
Fix typo
2023-09-11 03:02:11 +08:00
Josh Wu
05a54a4af9 Fix typo 2023-09-10 20:32:58 +08:00
Gilles Boccon-Gibod
1e00c8f456 Merge pull request #276 from google/gbg/add-zephyr-zip-to-docs
add zephyr binary to docs
2023-09-08 18:07:15 -07:00
Gilles Boccon-Gibod
90d165aa01 add zephyr binary 2023-09-08 14:17:15 -07:00
zxzxwu
01603ca9e4 Merge pull request #271 from zxzxwu/device_typing
Typing transport and relateds
2023-09-09 00:55:59 +08:00
Gilles Boccon-Gibod
a1b6eb61f2 Merge pull request #269 from google/gbg/android_vendor_hci
add support for vendor HCI commands and events
2023-09-08 08:50:49 -07:00
zxzxwu
25f300d3ec Merge pull request #270 from zxzxwu/typo
Fix typos
2023-09-08 17:32:33 +08:00
Josh Wu
41fe63df06 Fix typos 2023-09-08 16:30:06 +08:00
Josh Wu
b312170d5f Typing transport 2023-09-08 15:27:01 +08:00
David Duarte
cf7f2e8f44 Make platformdirs import lazy
platformdirs is not available in Android
2023-09-07 21:13:29 -07:00
Gilles Boccon-Gibod
d292083ed1 Merge pull request #272 from zxzxwu/gfp
Bring HfpProtocol back
2023-09-07 13:03:36 -07:00
Gilles Boccon-Gibod
9b11142b45 Merge pull request #267 from google/gbg/rfcomm-with-uuid
rfcomm with UUID
2023-09-07 13:01:56 -07:00
Hui Peng
acdbc4d7b9 Raise an exception when an L2cap connection fails 2023-09-07 19:24:38 +02:00
Josh Wu
838d10a09d Add HFP tests 2023-09-07 23:20:16 +08:00
Josh Wu
3852aa056b Bring HfpProtocol back 2023-09-07 23:20:09 +08:00
Gilles Boccon-Gibod
ae77e4528f add support for vendor HCI commands and events 2023-09-06 20:00:15 -07:00
Gilles Boccon-Gibod
8be9f4cb0e add doc and fix types 2023-09-06 17:05:30 -07:00
Gilles Boccon-Gibod
1ea12b1bf7 rebase 2023-09-06 17:05:24 -07:00
Gilles Boccon-Gibod
65e6d68355 add tcp server 2023-09-06 16:49:21 -07:00
44 changed files with 1303 additions and 375 deletions

View File

@@ -14,6 +14,10 @@ jobs:
check: check:
name: Check Code name: Check Code
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git

View File

@@ -12,10 +12,10 @@ permissions:
jobs: jobs:
build: build:
runs-on: ${{ matrix.os }}
runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11"] python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false fail-fast: false
@@ -41,6 +41,7 @@ jobs:
run: | run: |
inv build inv build
inv build.mkdocs inv build.mkdocs
build-rust: build-rust:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:

View File

@@ -102,9 +102,21 @@ class SnoopPacketReader:
default='h4', default='h4',
help='Format of the input file', help='Format of the input file',
) )
@click.option(
'--vendors',
type=click.Choice(['android', 'zephyr']),
multiple=True,
help='Support vendor-specific commands (list one or more)',
)
@click.argument('filename') @click.argument('filename')
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
def main(format, filename): def main(format, vendors, filename):
for vendor in vendors:
if vendor == 'android':
import bumble.vendor.android.hci
elif vendor == 'zephyr':
import bumble.vendor.zephyr.hci
input = open(filename, 'rb') input = open(filename, 'rb')
if format == 'h4': if format == 'h4':
packet_reader = PacketReader(input) packet_reader = PacketReader(input)
@@ -124,7 +136,6 @@ def main(format, filename):
if packet is None: if packet is None:
break break
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction) tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
except Exception as error: except Exception as error:
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))

View File

@@ -56,7 +56,7 @@ body, h1, h2, h3, h4, h5, h6 {
border-radius: 4px; border-radius: 4px;
padding: 4px; padding: 4px;
margin: 6px; margin: 6px;
margin-left: 0px; margin-left: 0;
} }
th, td { th, td {
@@ -65,7 +65,7 @@ th, td {
} }
.properties td:nth-child(even) { .properties td:nth-child(even) {
background-color: #D6EEEE; background-color: #d6eeee;
font-family: monospace; font-family: monospace;
} }

View File

@@ -2,7 +2,7 @@
<html> <html>
<head> <head>
<title>Bumble Speaker</title> <title>Bumble Speaker</title>
<script type="text/javascript" src="speaker.js"></script> <script src="speaker.js"></script>
<link rel="stylesheet" href="speaker.css"> <link rel="stylesheet" href="speaker.css">
</head> </head>
<body> <body>

View File

@@ -15,6 +15,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import asyncio import asyncio
import itertools import itertools
@@ -58,8 +60,10 @@ from bumble.hci import (
HCI_Packet, HCI_Packet,
HCI_Role_Change_Event, HCI_Role_Change_Event,
) )
from typing import Optional, Union, Dict from typing import Optional, Union, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -104,7 +108,7 @@ class Controller:
self, self,
name, name,
host_source=None, host_source=None,
host_sink=None, host_sink: Optional[TransportSink] = None,
link=None, link=None,
public_address: Optional[Union[bytes, str, Address]] = None, public_address: Optional[Union[bytes, str, Address]] = None,
): ):

View File

@@ -78,7 +78,13 @@ def get_dict_key_by_value(dictionary, value):
class BaseError(Exception): class BaseError(Exception):
"""Base class for errors with an error code, error name and namespace""" """Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''): def __init__(
self,
error_code: int | None,
error_namespace: str = '',
error_name: str = '',
details: str = '',
):
super().__init__() super().__init__()
self.error_code = error_code self.error_code = error_code
self.error_namespace = error_namespace self.error_namespace = error_namespace
@@ -90,12 +96,14 @@ class BaseError(Exception):
namespace = f'{self.error_namespace}/' namespace = f'{self.error_namespace}/'
else: else:
namespace = '' namespace = ''
if self.error_name: error_text = {
name = f'{self.error_name} [0x{self.error_code:X}]' (True, True): f'{self.error_name} [0x{self.error_code:X}]',
else: (True, False): self.error_name,
name = f'0x{self.error_code:X}' (False, True): f'0x{self.error_code:X}',
(False, False): '',
}[(self.error_name != '', self.error_code is not None)]
return f'{type(self).__name__}({namespace}{name})' return f'{type(self).__name__}({namespace}{error_text})'
class ProtocolError(BaseError): class ProtocolError(BaseError):
@@ -134,6 +142,10 @@ class ConnectionError(BaseError): # pylint: disable=redefined-builtin
self.peer_address = peer_address self.peer_address = peer_address
class ConnectionParameterUpdateError(BaseError):
"""Connection Parameter Update Error"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# UUID # UUID
# #

View File

@@ -23,7 +23,18 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
TYPE_CHECKING,
)
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -130,6 +141,7 @@ from .core import (
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
AdvertisingData, AdvertisingData,
ConnectionParameterUpdateError,
CommandTimeoutError, CommandTimeoutError,
ConnectionPHY, ConnectionPHY,
InvalidStateError, InvalidStateError,
@@ -152,6 +164,9 @@ from . import sdp
from . import l2cap from . import l2cap
from . import core from . import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -709,6 +724,7 @@ class Connection(CompositeEventEmitter):
connection_interval_max, connection_interval_max,
max_latency, max_latency,
supervision_timeout, supervision_timeout,
use_l2cap=False,
): ):
return await self.device.update_connection_parameters( return await self.device.update_connection_parameters(
self, self,
@@ -716,6 +732,7 @@ class Connection(CompositeEventEmitter):
connection_interval_max, connection_interval_max,
max_latency, max_latency,
supervision_timeout, supervision_timeout,
use_l2cap=use_l2cap,
) )
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None): async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
@@ -942,7 +959,13 @@ class Device(CompositeEventEmitter):
pass pass
@classmethod @classmethod
def with_hci(cls, name, address, hci_source, hci_sink): def with_hci(
cls,
name: str,
address: Address,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
''' '''
Create a Device instance with a Host configured to communicate with a controller Create a Device instance with a Host configured to communicate with a controller
through an HCI source/sink through an HCI source/sink
@@ -951,18 +974,25 @@ class Device(CompositeEventEmitter):
return cls(name=name, address=address, host=host) return cls(name=name, address=address, host=host)
@classmethod @classmethod
def from_config_file(cls, filename): def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls(config=config) return cls(config=config)
@classmethod @classmethod
def from_config_with_hci(cls, config, hci_source, hci_sink): def from_config_with_hci(
cls,
config: DeviceConfiguration,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
host = Host(controller_source=hci_source, controller_sink=hci_sink) host = Host(controller_source=hci_source, controller_sink=hci_sink)
return cls(config=config, host=host) return cls(config=config, host=host)
@classmethod @classmethod
def from_config_file_with_hci(cls, filename, hci_source, hci_sink): def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink) return cls.from_config_with_hci(config, hci_source, hci_sink)
@@ -2083,11 +2113,30 @@ class Device(CompositeEventEmitter):
supervision_timeout, supervision_timeout,
min_ce_length=0, min_ce_length=0,
max_ce_length=0, max_ce_length=0,
): use_l2cap=False,
) -> None:
''' '''
NOTE: the name of the parameters may look odd, but it just follows the names NOTE: the name of the parameters may look odd, but it just follows the names
used in the Bluetooth spec. used in the Bluetooth spec.
''' '''
if use_l2cap:
if connection.role != BT_PERIPHERAL_ROLE:
raise InvalidStateError(
'only peripheral can update connection parameters with l2cap'
)
l2cap_result = (
await self.l2cap_channel_manager.update_connection_parameters(
connection,
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout,
)
)
if l2cap_result != l2cap.L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT:
raise ConnectionParameterUpdateError(l2cap_result)
result = await self.send_command( result = await self.send_command(
HCI_LE_Connection_Update_Command( HCI_LE_Connection_Update_Command(
connection_handle=connection.handle, connection_handle=connection.handle,
@@ -2097,7 +2146,7 @@ class Device(CompositeEventEmitter):
supervision_timeout=supervision_timeout, supervision_timeout=supervision_timeout,
min_ce_length=min_ce_length, min_ce_length=min_ce_length,
max_ce_length=max_ce_length, max_ce_length=max_ce_length,
) ) # type: ignore[call-arg]
) )
if result.status != HCI_Command_Status_Event.PENDING: if result.status != HCI_Command_Status_Event.PENDING:
raise HCI_StatusError(result) raise HCI_StatusError(result)
@@ -2238,9 +2287,11 @@ class Device(CompositeEventEmitter):
def request_pairing(self, connection): def request_pairing(self, connection):
return self.smp_manager.request_pairing(connection) return self.smp_manager.request_pairing(connection)
async def get_long_term_key(self, connection_handle, rand, ediv): async def get_long_term_key(
self, connection_handle: int, rand: bytes, ediv: int
) -> Optional[bytes]:
if (connection := self.lookup_connection(connection_handle)) is None: if (connection := self.lookup_connection(connection_handle)) is None:
return return None
# Start by looking for the key in an SMP session # Start by looking for the key in an SMP session
ltk = self.smp_manager.get_long_term_key(connection, rand, ediv) ltk = self.smp_manager.get_long_term_key(connection, rand, ediv)
@@ -2260,6 +2311,7 @@ class Device(CompositeEventEmitter):
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value return keys.ltk_peripheral.value
return None
async def get_link_key(self, address: Address) -> Optional[bytes]: async def get_link_key(self, address: Address) -> Optional[bytes]:
if self.keystore is None: if self.keystore is None:

View File

@@ -23,7 +23,6 @@ import abc
import logging import logging
import pathlib import pathlib
import platform import platform
import platformdirs
from . import rtk from . import rtk
@@ -77,6 +76,8 @@ def project_data_dir() -> pathlib.Path:
A path to an OS-specific directory for bumble data. The directory is created if A path to an OS-specific directory for bumble data. The directory is created if
it doesn't exist. it doesn't exist.
""" """
import platformdirs
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id # platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
# out of author & project # out of author & project

View File

@@ -34,10 +34,9 @@ import weakref
from bumble.hci import ( from bumble.hci import (
hci_command_op_code, hci_vendor_command_op_code,
STATUS_SPEC, STATUS_SPEC,
HCI_SUCCESS, HCI_SUCCESS,
HCI_COMMAND_NAMES,
HCI_Command, HCI_Command,
HCI_Reset_Command, HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command, HCI_Read_Local_Version_Information_Command,
@@ -179,8 +178,10 @@ RTK_USB_PRODUCTS = {
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HCI Commands # HCI Commands
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_command_op_code(0x3F, 0x6D) HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
HCI_COMMAND_NAMES[HCI_RTK_READ_ROM_VERSION_COMMAND] = "HCI_RTK_READ_ROM_VERSION_COMMAND" HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
HCI_Command.register_commands(globals())
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)]) @HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
@@ -188,10 +189,6 @@ class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
pass pass
HCI_RTK_DOWNLOAD_COMMAND = hci_command_op_code(0x3F, 0x20)
HCI_COMMAND_NAMES[HCI_RTK_DOWNLOAD_COMMAND] = "HCI_RTK_DOWNLOAD_COMMAND"
@HCI_Command.command( @HCI_Command.command(
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)], fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)], return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
@@ -200,10 +197,6 @@ class HCI_RTK_Download_Command(HCI_Command):
pass pass
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_command_op_code(0x3F, 0x66)
HCI_COMMAND_NAMES[HCI_RTK_DROP_FIRMWARE_COMMAND] = "HCI_RTK_DROP_FIRMWARE_COMMAND"
@HCI_Command.command() @HCI_Command.command()
class HCI_RTK_Drop_Firmware_Command(HCI_Command): class HCI_RTK_Drop_Firmware_Command(HCI_Command):
pass pass

View File

@@ -16,11 +16,11 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import struct
import collections import collections
import logging
import functools import functools
from typing import Dict, Type, Union, Callable, Any, Optional import logging
import struct
from typing import Any, Dict, Callable, Optional, Type, Union
from .colors import color from .colors import color
from .core import ( from .core import (
@@ -47,6 +47,10 @@ def hci_command_op_code(ogf, ocf):
return ogf << 10 | ocf return ogf << 10 | ocf
def hci_vendor_command_op_code(ocf):
return hci_command_op_code(HCI_VENDOR_OGF, ocf)
def key_with_value(dictionary, target_value): def key_with_value(dictionary, target_value):
for key, value in dictionary.items(): for key, value in dictionary.items():
if value == target_value: if value == target_value:
@@ -101,6 +105,8 @@ def phy_list_to_bits(phys):
# fmt: off # fmt: off
# pylint: disable=line-too-long # pylint: disable=line-too-long
HCI_VENDOR_OGF = 0x3F
# HCI Version # HCI Version
HCI_VERSION_BLUETOOTH_CORE_1_0B = 0 HCI_VERSION_BLUETOOTH_CORE_1_0B = 0
HCI_VERSION_BLUETOOTH_CORE_1_1 = 1 HCI_VERSION_BLUETOOTH_CORE_1_1 = 1
@@ -206,10 +212,8 @@ HCI_INQUIRY_RESPONSE_NOTIFICATION_EVENT = 0X56
HCI_AUTHENTICATED_PAYLOAD_TIMEOUT_EXPIRED_EVENT = 0X57 HCI_AUTHENTICATED_PAYLOAD_TIMEOUT_EXPIRED_EVENT = 0X57
HCI_SAM_STATUS_CHANGE_EVENT = 0X58 HCI_SAM_STATUS_CHANGE_EVENT = 0X58
HCI_EVENT_NAMES = { HCI_VENDOR_EVENT = 0xFF
event_code: event_name for (event_name, event_code) in globals().items()
if event_name.startswith('HCI_') and event_name.endswith('_EVENT')
}
# HCI Subevent Codes # HCI Subevent Codes
HCI_LE_CONNECTION_COMPLETE_EVENT = 0x01 HCI_LE_CONNECTION_COMPLETE_EVENT = 0x01
@@ -248,10 +252,6 @@ HCI_LE_TRANSMIT_POWER_REPORTING_EVENT = 0X21
HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT = 0X22 HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT = 0X22
HCI_LE_SUBRATE_CHANGE_EVENT = 0X23 HCI_LE_SUBRATE_CHANGE_EVENT = 0X23
HCI_SUBEVENT_NAMES = {
event_code: event_name for (event_name, event_code) in globals().items()
if event_name.startswith('HCI_LE_') and event_name.endswith('_EVENT') and event_code != HCI_LE_META_EVENT
}
# HCI Command # HCI Command
HCI_INQUIRY_COMMAND = hci_command_op_code(0x01, 0x0001) HCI_INQUIRY_COMMAND = hci_command_op_code(0x01, 0x0001)
@@ -557,10 +557,6 @@ HCI_LE_SET_DATA_RELATED_ADDRESS_CHANGES_COMMAND = hci_c
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D) HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D)
HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E) HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E)
HCI_COMMAND_NAMES = {
command_code: command_name for (command_name, command_code) in globals().items()
if command_name.startswith('HCI_') and command_name.endswith('_COMMAND')
}
# HCI Error Codes # HCI Error Codes
# See Bluetooth spec Vol 2, Part D - 1.3 LIST OF ERROR CODES # See Bluetooth spec Vol 2, Part D - 1.3 LIST OF ERROR CODES
@@ -1960,6 +1956,7 @@ class HCI_Command(HCI_Packet):
''' '''
hci_packet_type = HCI_COMMAND_PACKET hci_packet_type = HCI_COMMAND_PACKET
command_names: Dict[int, str] = {}
command_classes: Dict[int, Type[HCI_Command]] = {} command_classes: Dict[int, Type[HCI_Command]] = {}
@staticmethod @staticmethod
@@ -1970,9 +1967,9 @@ class HCI_Command(HCI_Packet):
def inner(cls): def inner(cls):
cls.name = cls.__name__.upper() cls.name = cls.__name__.upper()
cls.op_code = key_with_value(HCI_COMMAND_NAMES, cls.name) cls.op_code = key_with_value(cls.command_names, cls.name)
if cls.op_code is None: if cls.op_code is None:
raise KeyError(f'command {cls.name} not found in HCI_COMMAND_NAMES') raise KeyError(f'command {cls.name} not found in command_names')
cls.fields = fields cls.fields = fields
cls.return_parameters_fields = return_parameters_fields cls.return_parameters_fields = return_parameters_fields
@@ -1991,6 +1988,18 @@ class HCI_Command(HCI_Packet):
return inner return inner
@staticmethod
def command_map(symbols: Dict[str, Any]) -> Dict[int, str]:
return {
command_code: command_name
for (command_name, command_code) in symbols.items()
if command_name.startswith('HCI_') and command_name.endswith('_COMMAND')
}
@classmethod
def register_commands(cls, symbols: Dict[str, Any]) -> None:
cls.command_names.update(cls.command_map(symbols))
@staticmethod @staticmethod
def from_bytes(packet: bytes) -> HCI_Command: def from_bytes(packet: bytes) -> HCI_Command:
op_code, length = struct.unpack_from('<HB', packet, 1) op_code, length = struct.unpack_from('<HB', packet, 1)
@@ -2015,7 +2024,7 @@ class HCI_Command(HCI_Packet):
@staticmethod @staticmethod
def command_name(op_code): def command_name(op_code):
name = HCI_COMMAND_NAMES.get(op_code) name = HCI_Command.command_names.get(op_code)
if name is not None: if name is not None:
return name return name
return f'[OGF=0x{op_code >> 10:02x}, OCF=0x{op_code & 0x3FF:04x}]' return f'[OGF=0x{op_code >> 10:02x}, OCF=0x{op_code & 0x3FF:04x}]'
@@ -2024,6 +2033,16 @@ class HCI_Command(HCI_Packet):
def create_return_parameters(cls, **kwargs): def create_return_parameters(cls, **kwargs):
return HCI_Object(cls.return_parameters_fields, **kwargs) return HCI_Object(cls.return_parameters_fields, **kwargs)
@classmethod
def parse_return_parameters(cls, parameters):
if not cls.return_parameters_fields:
return None
return_parameters = HCI_Object.from_bytes(
parameters, 0, cls.return_parameters_fields
)
return_parameters.fields = cls.return_parameters_fields
return return_parameters
def __init__(self, op_code, parameters=None, **kwargs): def __init__(self, op_code, parameters=None, **kwargs):
super().__init__(HCI_Command.command_name(op_code)) super().__init__(HCI_Command.command_name(op_code))
if (fields := getattr(self, 'fields', None)) and kwargs: if (fields := getattr(self, 'fields', None)) and kwargs:
@@ -2053,6 +2072,9 @@ class HCI_Command(HCI_Packet):
return result return result
HCI_Command.register_commands(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command( @HCI_Command.command(
[ [
@@ -4308,8 +4330,8 @@ class HCI_Event(HCI_Packet):
''' '''
hci_packet_type = HCI_EVENT_PACKET hci_packet_type = HCI_EVENT_PACKET
event_names: Dict[int, str] = {}
event_classes: Dict[int, Type[HCI_Event]] = {} event_classes: Dict[int, Type[HCI_Event]] = {}
meta_event_classes: Dict[int, Type[HCI_LE_Meta_Event]] = {}
@staticmethod @staticmethod
def event(fields=()): def event(fields=()):
@@ -4319,9 +4341,9 @@ class HCI_Event(HCI_Packet):
def inner(cls): def inner(cls):
cls.name = cls.__name__.upper() cls.name = cls.__name__.upper()
cls.event_code = key_with_value(HCI_EVENT_NAMES, cls.name) cls.event_code = key_with_value(cls.event_names, cls.name)
if cls.event_code is None: if cls.event_code is None:
raise KeyError('event not found in HCI_EVENT_NAMES') raise KeyError(f'event {cls.name} not found in event_names')
cls.fields = fields cls.fields = fields
# Patch the __init__ method to fix the event_code # Patch the __init__ method to fix the event_code
@@ -4337,12 +4359,30 @@ class HCI_Event(HCI_Packet):
return inner return inner
@staticmethod
def event_map(symbols: Dict[str, Any]) -> Dict[int, str]:
return {
event_code: event_name
for (event_name, event_code) in symbols.items()
if event_name.startswith('HCI_')
and not event_name.startswith('HCI_LE_')
and event_name.endswith('_EVENT')
}
@staticmethod
def event_name(event_code):
return name_or_number(HCI_Event.event_names, event_code)
@staticmethod
def register_events(symbols: Dict[str, Any]) -> None:
HCI_Event.event_names.update(HCI_Event.event_map(symbols))
@staticmethod @staticmethod
def registered(event_class): def registered(event_class):
event_class.name = event_class.__name__.upper() event_class.name = event_class.__name__.upper()
event_class.event_code = key_with_value(HCI_EVENT_NAMES, event_class.name) event_class.event_code = key_with_value(HCI_Event.event_names, event_class.name)
if event_class.event_code is None: if event_class.event_code is None:
raise KeyError('event not found in HCI_EVENT_NAMES') raise KeyError(f'event {event_class.name} not found in event_names')
# Register a factory for this class # Register a factory for this class
HCI_Event.event_classes[event_class.event_code] = event_class HCI_Event.event_classes[event_class.event_code] = event_class
@@ -4362,11 +4402,16 @@ class HCI_Event(HCI_Packet):
# We do this dispatch here and not in the subclass in order to avoid call # We do this dispatch here and not in the subclass in order to avoid call
# loops # loops
subevent_code = parameters[0] subevent_code = parameters[0]
cls = HCI_Event.meta_event_classes.get(subevent_code) cls = HCI_LE_Meta_Event.subevent_classes.get(subevent_code)
if cls is None: if cls is None:
# No class registered, just use a generic class instance # No class registered, just use a generic class instance
return HCI_LE_Meta_Event(subevent_code, parameters) return HCI_LE_Meta_Event(subevent_code, parameters)
elif event_code == HCI_VENDOR_EVENT:
subevent_code = parameters[0]
cls = HCI_Vendor_Event.subevent_classes.get(subevent_code)
if cls is None:
# No class registered, just use a generic class instance
return HCI_Vendor_Event(subevent_code, parameters)
else: else:
cls = HCI_Event.event_classes.get(event_code) cls = HCI_Event.event_classes.get(event_code)
if cls is None: if cls is None:
@@ -4384,10 +4429,6 @@ class HCI_Event(HCI_Packet):
HCI_Object.init_from_bytes(self, parameters, 0, fields) HCI_Object.init_from_bytes(self, parameters, 0, fields)
return self return self
@staticmethod
def event_name(event_code):
return name_or_number(HCI_EVENT_NAMES, event_code)
def __init__(self, event_code, parameters=None, **kwargs): def __init__(self, event_code, parameters=None, **kwargs):
super().__init__(HCI_Event.event_name(event_code)) super().__init__(HCI_Event.event_name(event_code))
if (fields := getattr(self, 'fields', None)) and kwargs: if (fields := getattr(self, 'fields', None)) and kwargs:
@@ -4414,71 +4455,111 @@ class HCI_Event(HCI_Packet):
return result return result
HCI_Event.register_events(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HCI_LE_Meta_Event(HCI_Event): class HCI_Extended_Event(HCI_Event):
''' '''
See Bluetooth spec @ 7.7.65 LE Meta Event HCI_Event subclass for events that has a subevent code.
''' '''
@staticmethod subevent_names: Dict[int, str] = {}
def event(fields=()): subevent_classes: Dict[int, Type[HCI_Extended_Event]]
@classmethod
def event(cls, fields=()):
''' '''
Decorator used to declare and register subclasses Decorator used to declare and register subclasses
''' '''
def inner(cls): def inner(cls):
cls.name = cls.__name__.upper() cls.name = cls.__name__.upper()
cls.subevent_code = key_with_value(HCI_SUBEVENT_NAMES, cls.name) cls.subevent_code = key_with_value(cls.subevent_names, cls.name)
if cls.subevent_code is None: if cls.subevent_code is None:
raise KeyError('subevent not found in HCI_SUBEVENT_NAMES') raise KeyError(f'subevent {cls.name} not found in subevent_names')
cls.fields = fields cls.fields = fields
# Patch the __init__ method to fix the subevent_code # Patch the __init__ method to fix the subevent_code
original_init = cls.__init__
def init(self, parameters=None, **kwargs): def init(self, parameters=None, **kwargs):
return HCI_LE_Meta_Event.__init__( return original_init(self, cls.subevent_code, parameters, **kwargs)
self, cls.subevent_code, parameters, **kwargs
)
cls.__init__ = init cls.__init__ = init
# Register a factory for this class # Register a factory for this class
HCI_Event.meta_event_classes[cls.subevent_code] = cls cls.subevent_classes[cls.subevent_code] = cls
return cls return cls
return inner return inner
@classmethod
def subevent_name(cls, subevent_code):
subevent_name = cls.subevent_names.get(subevent_code)
if subevent_name is not None:
return subevent_name
return f'{cls.__name__.upper()}[0x{subevent_code:02X}]'
@staticmethod
def subevent_map(symbols: Dict[str, Any]) -> Dict[int, str]:
return {
subevent_code: subevent_name
for (subevent_name, subevent_code) in symbols.items()
if subevent_name.startswith('HCI_') and subevent_name.endswith('_EVENT')
}
@classmethod
def register_subevents(cls, symbols: Dict[str, Any]) -> None:
cls.subevent_names.update(cls.subevent_map(symbols))
@classmethod @classmethod
def from_parameters(cls, parameters): def from_parameters(cls, parameters):
self = cls.__new__(cls) self = cls.__new__(cls)
HCI_LE_Meta_Event.__init__(self, self.subevent_code, parameters) HCI_Extended_Event.__init__(self, self.subevent_code, parameters)
if fields := getattr(self, 'fields', None): if fields := getattr(self, 'fields', None):
HCI_Object.init_from_bytes(self, parameters, 1, fields) HCI_Object.init_from_bytes(self, parameters, 1, fields)
return self return self
@staticmethod
def subevent_name(subevent_code):
return name_or_number(HCI_SUBEVENT_NAMES, subevent_code)
def __init__(self, subevent_code, parameters, **kwargs): def __init__(self, subevent_code, parameters, **kwargs):
self.subevent_code = subevent_code self.subevent_code = subevent_code
if parameters is None and (fields := getattr(self, 'fields', None)) and kwargs: if parameters is None and (fields := getattr(self, 'fields', None)) and kwargs:
parameters = bytes([subevent_code]) + HCI_Object.dict_to_bytes( parameters = bytes([subevent_code]) + HCI_Object.dict_to_bytes(
kwargs, fields kwargs, fields
) )
super().__init__(HCI_LE_META_EVENT, parameters, **kwargs) super().__init__(self.event_code, parameters, **kwargs)
# Override the name in order to adopt the subevent name instead # Override the name in order to adopt the subevent name instead
self.name = self.subevent_name(subevent_code) self.name = self.subevent_name(subevent_code)
def __str__(self):
result = color(self.subevent_name(self.subevent_code), 'magenta') # -----------------------------------------------------------------------------
if fields := getattr(self, 'fields', None): class HCI_LE_Meta_Event(HCI_Extended_Event):
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ') '''
else: See Bluetooth spec @ 7.7.65 LE Meta Event
if self.parameters: '''
result += f': {self.parameters.hex()}'
return result event_code: int = HCI_LE_META_EVENT
subevent_classes = {}
@staticmethod
def subevent_map(symbols: Dict[str, Any]) -> Dict[int, str]:
return {
subevent_code: subevent_name
for (subevent_name, subevent_code) in symbols.items()
if subevent_name.startswith('HCI_LE_') and subevent_name.endswith('_EVENT')
}
HCI_LE_Meta_Event.register_subevents(globals())
# -----------------------------------------------------------------------------
class HCI_Vendor_Event(HCI_Extended_Event):
event_code: int = HCI_VENDOR_EVENT
subevent_classes = {}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -4592,7 +4673,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event):
return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}' return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}'
HCI_Event.meta_event_classes[ HCI_LE_Meta_Event.subevent_classes[
HCI_LE_ADVERTISING_REPORT_EVENT HCI_LE_ADVERTISING_REPORT_EVENT
] = HCI_LE_Advertising_Report_Event ] = HCI_LE_Advertising_Report_Event
@@ -4846,7 +4927,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event):
return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}' return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}'
HCI_Event.meta_event_classes[ HCI_LE_Meta_Event.subevent_classes[
HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT
] = HCI_LE_Extended_Advertising_Report_Event ] = HCI_LE_Extended_Advertising_Report_Event
@@ -5120,11 +5201,11 @@ class HCI_Command_Complete_Event(HCI_Event):
self.return_parameters = self.return_parameters[0] self.return_parameters = self.return_parameters[0]
else: else:
cls = HCI_Command.command_classes.get(self.command_opcode) cls = HCI_Command.command_classes.get(self.command_opcode)
if cls and cls.return_parameters_fields: if cls:
self.return_parameters = HCI_Object.from_bytes( # Try to parse the return parameters bytes into an object.
self.return_parameters, 0, cls.return_parameters_fields return_parameters = cls.parse_return_parameters(self.return_parameters)
) if return_parameters is not None:
self.return_parameters.fields = cls.return_parameters_fields self.return_parameters = return_parameters
return self return self

View File

@@ -15,16 +15,19 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import collections.abc
import logging import logging
import asyncio import asyncio
import dataclasses import dataclasses
import enum import enum
import traceback import traceback
from typing import Dict, List, Union, Set import warnings
from typing import Dict, List, Union, Set, TYPE_CHECKING
from . import at from . import at
from . import rfcomm from . import rfcomm
from bumble.colors import color
from bumble.core import ( from bumble.core import (
ProtocolError, ProtocolError,
BT_GENERIC_AUDIO_SERVICE, BT_GENERIC_AUDIO_SERVICE,
@@ -48,6 +51,71 @@ from bumble.sdp import (
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Error
# -----------------------------------------------------------------------------
class HfpProtocolError(ProtocolError):
def __init__(self, error_name: str = '', details: str = ''):
super().__init__(None, 'hfp', error_name, details)
# -----------------------------------------------------------------------------
# Protocol Support
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class HfpProtocol:
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
warnings.warn("See HfProtocol", DeprecationWarning)
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
dlc.sink = self.feed
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1 :]
if len(line) > 0:
self.on_line(line)
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
self.lines_available.clear()
logger.debug(color(f'<<< {line}', 'green'))
return line
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Normative protocol definitions # Normative protocol definitions
@@ -302,8 +370,12 @@ class HfProtocol:
dlc: rfcomm.DLC dlc: rfcomm.DLC
command_lock: asyncio.Lock command_lock: asyncio.Lock
response_queue: asyncio.Queue if TYPE_CHECKING:
unsolicited_queue: asyncio.Queue response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray read_buffer: bytearray
def __init__(self, dlc: rfcomm.DLC, configuration: Configuration): def __init__(self, dlc: rfcomm.DLC, configuration: Configuration):
@@ -368,7 +440,7 @@ class HfProtocol:
else: else:
logger.warning(f"dropping unexpected response with code '{response.code}'") logger.warning(f"dropping unexpected response with code '{response.code}'")
# Send an AT command and wait for the peer resposne. # Send an AT command and wait for the peer response.
# Wait for the AT responses sent by the peer, to the status code. # Wait for the AT responses sent by the peer, to the status code.
# Raises asyncio.TimeoutError if the status is not received # Raises asyncio.TimeoutError if the status is not received
# after a timeout (default 1 second). # after a timeout (default 1 second).
@@ -390,7 +462,7 @@ class HfProtocol:
) )
if result.code == 'OK': if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1: if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise ProtocolError("NO ANSWER") raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE: if response_type == AtResponseType.MULTIPLE:
return responses return responses
@@ -398,7 +470,7 @@ class HfProtocol:
return responses[0] return responses[0]
return None return None
if result.code in STATUS_CODES: if result.code in STATUS_CODES:
raise ProtocolError(result.code) raise HfpProtocolError(result.code)
responses.append(result) responses.append(result)
# 4.2.1 Service Level Connection Initialization. # 4.2.1 Service Level Connection Initialization.

View File

@@ -21,7 +21,7 @@ import collections
import logging import logging
import struct import struct
from typing import Optional from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
@@ -73,10 +73,14 @@ from .core import (
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
ConnectionPHY, ConnectionPHY,
ConnectionParameters, ConnectionParameters,
InvalidStateError,
) )
from .utils import AbortableEventEmitter from .utils import AbortableEventEmitter
from .transport.common import TransportLostError from .transport.common import TransportLostError
if TYPE_CHECKING:
from .transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -116,10 +120,21 @@ class Connection:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None): connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]
def __init__(
self,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
) -> None:
super().__init__() super().__init__()
self.hci_sink = None
self.hci_metadata = None self.hci_metadata = None
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.reset_done = False self.reset_done = False
@@ -299,7 +314,7 @@ class Host(AbortableEventEmitter):
self.reset_done = True self.reset_done = True
@property @property
def controller(self): def controller(self) -> TransportSink:
return self.hci_sink return self.hci_sink
@controller.setter @controller.setter
@@ -308,13 +323,12 @@ class Host(AbortableEventEmitter):
if controller: if controller:
controller.set_packet_sink(self) controller.set_packet_sink(self)
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.hci_sink = sink self.hci_sink = sink
def send_hci_packet(self, packet: HCI_Packet) -> None: def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper: if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(bytes(packet)) self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False): async def send_command(self, command, check_result=False):

View File

@@ -1387,6 +1387,7 @@ class ChannelManager:
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host] _host: Optional[Host]
connection_parameters_update_response: Optional[asyncio.Future[int]]
def __init__( def __init__(
self, self,
@@ -1408,6 +1409,7 @@ class ChannelManager:
self.le_coc_requests = {} # LE CoC connection requests, by identifier self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.extended_features = extended_features self.extended_features = extended_features
self.connectionless_mtu = connectionless_mtu self.connectionless_mtu = connectionless_mtu
self.connection_parameters_update_response = None
@property @property
def host(self) -> Host: def host(self) -> Host:
@@ -1865,11 +1867,45 @@ class ChannelManager:
), ),
) )
async def update_connection_parameters(
self,
connection: Connection,
interval_min: int,
interval_max: int,
latency: int,
timeout: int,
) -> int:
# Check that there isn't already a request pending
if self.connection_parameters_update_response:
raise InvalidStateError('request already pending')
self.connection_parameters_update_response = (
asyncio.get_running_loop().create_future()
)
self.send_control_frame(
connection,
L2CAP_LE_SIGNALING_CID,
L2CAP_Connection_Parameter_Update_Request(
interval_min=interval_min,
interval_max=interval_max,
latency=latency,
timeout=timeout,
),
)
return await self.connection_parameters_update_response
def on_l2cap_connection_parameter_update_response( def on_l2cap_connection_parameter_update_response(
self, connection: Connection, cid: int, response self, connection: Connection, cid: int, response
) -> None: ) -> None:
# TODO: check response if self.connection_parameters_update_response:
pass self.connection_parameters_update_response.set_result(response.result)
self.connection_parameters_update_response = None
else:
logger.warning(
color(
'received l2cap_connection_parameter_update_response without a pending request',
'red',
)
)
def on_l2cap_le_credit_based_connection_request( def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request self, connection: Connection, cid: int, request
@@ -2078,7 +2114,8 @@ class ChannelManager:
# Connect # Connect
try: try:
await channel.connect() await channel.connect()
except Exception: except Exception as e:
del connection_channels[source_cid] del connection_channels[source_cid]
raise e
return channel return channel

View File

@@ -20,13 +20,29 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import enum import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, Tuple, Callable, Dict, Union, TYPE_CHECKING
from . import core, l2cap from . import core, l2cap
from .colors import color from .colors import color
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError from .core import (
UUID,
BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID,
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -111,6 +127,50 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on # fmt: on
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
]
),
]
),
),
]
if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
)
)
return records
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int: def compute_fcs(buffer: bytes) -> int:
result = 0xFF result = 0xFF

View File

@@ -20,7 +20,6 @@ import logging
import os import os
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper from ..snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -119,7 +118,8 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'file': if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None) assert spec is not None
return await open_file_transport(spec[0])
if scheme == 'vhci': if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
@@ -134,12 +134,14 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'usb': if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None) assert spec is not None
return await open_usb_transport(spec[0])
if scheme == 'pyusb': if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None) assert spec is not None
return await open_pyusb_transport(spec[0])
if scheme == 'android-emulator': if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
@@ -168,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport:
""" """
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
from ..controller import Controller
from ..link import RemoteLink # lazy import from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])

View File

@@ -18,7 +18,7 @@
import logging import logging
import grpc.aio import grpc.aio
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec): async def open_android_emulator_transport(spec: str | None) -> Transport:
''' '''
Open a transport connection to an Android emulator via its gRPC interface. Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax: The parameter string has this syntax:
@@ -66,7 +66,7 @@ async def open_android_emulator_transport(spec):
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = 8554 server_port = '8554'
if spec is not None: if spec is not None:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
@@ -82,6 +82,7 @@ async def open_android_emulator_transport(spec):
logger.debug(f'connecting to gRPC server at {server_address}') logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
if mode == 'host': if mode == 'host':
# Connect as a host # Connect as a host
service = EmulatedBluetoothServiceStub(channel) service = EmulatedBluetoothServiceStub(channel)

View File

@@ -121,7 +121,9 @@ def publish_grpc_port(grpc_port) -> bool:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(server_host, server_port): async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int
) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise ValueError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:

View File

@@ -20,11 +20,12 @@ import contextlib
import struct import struct
import asyncio import asyncio
import logging import logging
from typing import ContextManager import io
from typing import ContextManager, Tuple, Optional, Protocol, Dict
from .. import hci from bumble import hci
from ..colors import color from bumble.colors import color
from ..snoop import Snooper from bumble.snoop import Snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -36,7 +37,7 @@ logger = logging.getLogger(__name__)
# Information needed to parse HCI packets with a generic parser: # Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents: # For each packet type, the info represents:
# (length-size, length-offset, unpack-type) # (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = { HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
@@ -44,6 +45,8 @@ HCI_PACKET_INFO = {
} }
# -----------------------------------------------------------------------------
# Errors
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TransportLostError(Exception): class TransportLostError(Exception):
""" """
@@ -51,24 +54,36 @@ class TransportLostError(Exception):
""" """
# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None:
...
class TransportSource(Protocol):
terminated: asyncio.Future[None]
def set_packet_sink(self, sink: TransportSink) -> None:
...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketPump: class PacketPump:
""" """
Pump HCI packets from a reader to a sink. Pump HCI packets from a reader to a sink.
""" """
def __init__(self, reader, sink): def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
self.reader = reader self.reader = reader
self.sink = sink self.sink = sink
async def run(self): async def run(self) -> None:
while True: while True:
try: try:
# Get a packet from the source
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
# Deliver the packet to the sink # Deliver the packet to the sink
self.sink.on_packet(packet) self.sink.on_packet(await self.reader.next_packet())
except Exception as error: except Exception as error:
logger.warning(f'!!! {error}') logger.warning(f'!!! {error}')
@@ -86,18 +101,22 @@ class PacketParser:
NEED_LENGTH = 1 NEED_LENGTH = 1
NEED_BODY = 2 NEED_BODY = 2
def __init__(self, sink=None): sink: Optional[TransportSink]
extended_packet_info: Dict[int, Tuple[int, int, str]]
packet_info: Optional[Tuple[int, int, str]] = None
def __init__(self, sink: Optional[TransportSink] = None) -> None:
self.sink = sink self.sink = sink
self.extended_packet_info = {} self.extended_packet_info = {}
self.reset() self.reset()
def reset(self): def reset(self) -> None:
self.state = PacketParser.NEED_TYPE self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1 self.bytes_needed = 1
self.packet = bytearray() self.packet = bytearray()
self.packet_info = None self.packet_info = None
def feed_data(self, data): def feed_data(self, data: bytes) -> None:
data_offset = 0 data_offset = 0
data_left = len(data) data_left = len(data)
while data_left and self.bytes_needed: while data_left and self.bytes_needed:
@@ -118,6 +137,7 @@ class PacketParser:
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
assert self.packet_info is not None
body_length = struct.unpack_from( body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1] self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0] )[0]
@@ -135,7 +155,7 @@ class PacketParser:
) )
self.reset() self.reset()
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
@@ -145,10 +165,10 @@ class PacketReader:
Reader that reads HCI packets from a sync source. Reader that reads HCI packets from a sync source.
""" """
def __init__(self, source): def __init__(self, source: io.BufferedReader) -> None:
self.source = source self.source = source
def next_packet(self): def next_packet(self) -> Optional[bytes]:
# Get the packet type # Get the packet type
packet_type = self.source.read(1) packet_type = self.source.read(1)
if len(packet_type) != 1: if len(packet_type) != 1:
@@ -157,7 +177,7 @@ class PacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found') raise ValueError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -180,17 +200,17 @@ class AsyncPacketReader:
Reader that reads HCI packets from an async source. Reader that reads HCI packets from an async source.
""" """
def __init__(self, source): def __init__(self, source: asyncio.StreamReader) -> None:
self.source = source self.source = source
async def next_packet(self): async def next_packet(self) -> bytes:
# Get the packet type # Get the packet type
packet_type = await self.source.readexactly(1) packet_type = await self.source.readexactly(1)
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found') raise ValueError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -209,11 +229,11 @@ class AsyncPipeSink:
Sink that forwards packets asynchronously to another sink. Sink that forwards packets asynchronously to another sink.
""" """
def __init__(self, sink): def __init__(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.loop.call_soon(self.sink.on_packet, packet) self.loop.call_soon(self.sink.on_packet, packet)
@@ -223,50 +243,48 @@ class ParserSource:
Base class designed to be subclassed by transport-specific source classes Base class designed to be subclassed by transport-specific source classes
""" """
terminated: asyncio.Future terminated: asyncio.Future[None]
parser: PacketParser parser: PacketParser
def __init__(self): def __init__(self) -> None:
self.parser = PacketParser() self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink) self.parser.set_packet_sink(sink)
def on_transport_lost(self): def on_transport_lost(self) -> None:
self.terminated.set_result(None) self.terminated.set_result(None)
if self.parser.sink: if self.parser.sink:
try: if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost() self.parser.sink.on_transport_lost()
except AttributeError:
pass
async def wait_for_termination(self): async def wait_for_termination(self) -> None:
""" """
Convenience method for backward compatibility. Prefer using the `terminated` Convenience method for backward compatibility. Prefer using the `terminated`
attribute instead. attribute instead.
""" """
return await self.terminated return await self.terminated
def close(self): def close(self) -> None:
pass pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data): def data_received(self, data: bytes) -> None:
self.parser.feed_data(data) self.parser.feed_data(data)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSink: class StreamPacketSink:
def __init__(self, transport): def __init__(self, transport: asyncio.WriteTransport) -> None:
self.transport = transport self.transport = transport
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.transport.write(packet) self.transport.write(packet)
def close(self): def close(self) -> None:
self.transport.close() self.transport.close()
@@ -286,7 +304,7 @@ class Transport:
... ...
""" """
def __init__(self, source, sink): def __init__(self, source: TransportSource, sink: TransportSink) -> None:
self.source = source self.source = source
self.sink = sink self.sink = sink
@@ -300,19 +318,23 @@ class Transport:
return iter((self.source, self.sink)) return iter((self.source, self.sink))
async def close(self) -> None: async def close(self) -> None:
self.source.close() if hasattr(self.source, 'close'):
self.sink.close() self.source.close()
if hasattr(self.sink, 'close'):
self.sink.close()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource): class PumpedPacketSource(ParserSource):
def __init__(self, receive): pump_task: Optional[asyncio.Task[None]]
def __init__(self, receive) -> None:
super().__init__() super().__init__()
self.receive_function = receive self.receive_function = receive
self.pump_task = None self.pump_task = None
def start(self): def start(self) -> None:
async def pump_packets(): async def pump_packets() -> None:
while True: while True:
try: try:
packet = await self.receive_function() packet = await self.receive_function()
@@ -322,12 +344,12 @@ class PumpedPacketSource(ParserSource):
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error) self.terminated.set_exception(error)
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self) -> None:
if self.pump_task: if self.pump_task:
self.pump_task.cancel() self.pump_task.cancel()
@@ -339,7 +361,7 @@ class PumpedPacketSink:
self.packet_queue = asyncio.Queue() self.packet_queue = asyncio.Queue()
self.pump_task = None self.pump_task = None
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.packet_queue.put_nowait(packet) self.packet_queue.put_nowait(packet)
def start(self): def start(self):
@@ -364,15 +386,23 @@ class PumpedPacketSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedTransport(Transport): class PumpedTransport(Transport):
def __init__(self, source, sink, close_function): source: PumpedPacketSource
sink: PumpedPacketSink
def __init__(
self,
source: PumpedPacketSource,
sink: PumpedPacketSink,
close_function,
) -> None:
super().__init__(source, sink) super().__init__(source, sink)
self.close_function = close_function self.close_function = close_function
def start(self): def start(self) -> None:
self.source.start() self.source.start()
self.sink.start() self.sink.start()
async def close(self): async def close(self) -> None:
await super().close() await super().close()
await self.close_function() await self.close_function()
@@ -397,31 +427,38 @@ class SnoopingTransport(Transport):
raise RuntimeError('unexpected code path') # Satisfy the type checker raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source: class Source:
def __init__(self, source, snooper): sink: TransportSink
def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source self.source = source
self.snooper = snooper self.snooper = snooper
self.sink = None self.terminated = source.terminated
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
self.source.set_packet_sink(self) self.source.set_packet_sink(self)
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink: if self.sink:
self.sink.on_packet(packet) self.sink.on_packet(packet)
class Sink: class Sink:
def __init__(self, sink, snooper): def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
self.sink = sink self.sink = sink
self.snooper = snooper self.snooper = snooper
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink: if self.sink:
self.sink.on_packet(packet) self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None): def __init__(
self,
transport: Transport,
snooper: Snooper,
close_snooper=None,
) -> None:
super().__init__( super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper) self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
) )

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_file_transport(spec): async def open_file_transport(spec: str) -> Transport:
''' '''
Open a File transport (typically not for a real file, but for a PTY or other unix Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files). virtual files).

View File

@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec): async def open_hci_socket_transport(spec: str | None) -> Transport:
''' '''
Open an HCI Socket (only available on some platforms). Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter) The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -47,7 +47,7 @@ async def open_hci_socket_transport(spec):
hci_socket = socket.socket( hci_socket = socket.socket(
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI, socket.BTPROTO_HCI, # type: ignore
) )
except AttributeError as error: except AttributeError as error:
# Not supported on this platform # Not supported on this platform

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pty_transport(spec): async def open_pty_transport(spec: str | None) -> Transport:
''' '''
Open a PTY transport. Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link The parameter string may be empty, or a path name where a symbolic link

View File

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pyusb_transport(spec): async def open_pyusb_transport(spec: str) -> Transport:
''' '''
Open a USB transport. [Implementation based on PyUSB] Open a USB transport. [Implementation based on PyUSB]
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_serial_transport(spec): async def open_serial_transport(spec: str) -> Transport:
''' '''
Open a serial port transport. Open a serial port transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_tcp_client_transport(spec): async def open_tcp_client_transport(spec: str) -> Transport:
''' '''
Open a TCP client transport. Open a TCP client transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_tcp_server_transport(spec): async def open_tcp_server_transport(spec: str) -> Transport:
''' '''
Open a TCP server transport. Open a TCP server transport.
The parameter string has this syntax: The parameter string has this syntax:
@@ -42,7 +43,7 @@ async def open_tcp_server_transport(spec):
async def close(self): async def close(self):
await super().close() await super().close()
class TcpServerProtocol: class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink): def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source self.packet_source = packet_source
self.packet_sink = packet_sink self.packet_sink = packet_sink

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_udp_transport(spec): async def open_udp_transport(spec: str) -> Transport:
''' '''
Open a UDP transport. Open a UDP transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -60,7 +60,7 @@ def load_libusb():
usb1.loadLibrary(libusb_dll) usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec): async def open_usb_transport(spec: str) -> Transport:
''' '''
Open a USB transport. Open a USB transport.
The moniker string has this syntax: The moniker string has this syntax:

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from .common import Transport
from .file import open_file_transport from .file import open_file_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_vhci_transport(spec): async def open_vhci_transport(spec: str | None) -> Transport:
''' '''
Open a VHCI transport (only available on some platforms). Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device The parameter string is either empty (to use the default VHCI device
@@ -42,15 +43,15 @@ async def open_vhci_transport(spec):
# Override the source's `data_received` method so that we can # Override the source's `data_received` method so that we can
# filter out the vendor packet that is received just after the # filter out the vendor packet that is received just after the
# initial open # initial open
def vhci_data_received(data): def vhci_data_received(data: bytes) -> None:
if len(data) > 0 and data[0] == HCI_VENDOR_PKT: if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
if len(data) == 4: if len(data) == 4:
hci_index = data[2] << 8 | data[3] hci_index = data[2] << 8 | data[3]
logger.info(f'HCI index {hci_index}') logger.info(f'HCI index {hci_index}')
else: else:
transport.source.parser.feed_data(data) transport.source.parser.feed_data(data) # type: ignore
transport.source.data_received = vhci_data_received transport.source.data_received = vhci_data_received # type: ignore
# Write the initial config # Write the initial config
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))

View File

@@ -16,9 +16,9 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import websockets import websockets.client
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_ws_client_transport(spec): async def open_ws_client_transport(spec: str) -> Transport:
''' '''
Open a WebSocket client transport. Open a WebSocket client transport.
The parameter string has this syntax: The parameter string has this syntax:
@@ -38,7 +38,7 @@ async def open_ws_client_transport(spec):
remote_host, remote_port = spec.split(':') remote_host, remote_port = spec.split(':')
uri = f'ws://{remote_host}:{remote_port}' uri = f'ws://{remote_host}:{remote_port}'
websocket = await websockets.connect(uri) websocket = await websockets.client.connect(uri)
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(websocket.recv), PumpedPacketSource(websocket.recv),

View File

@@ -15,7 +15,6 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import websockets import websockets
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_ws_server_transport(spec): async def open_ws_server_transport(spec: str) -> Transport:
''' '''
Open a WebSocket server transport. Open a WebSocket server transport.
The parameter string has this syntax: The parameter string has this syntax:

0
bumble/vendor/__init__.py vendored Normal file
View File

0
bumble/vendor/android/__init__.py vendored Normal file
View File

318
bumble/vendor/android/hci.py vendored Normal file
View File

@@ -0,0 +1,318 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from bumble.hci import (
name_or_number,
hci_vendor_command_op_code,
Address,
HCI_Constant,
HCI_Object,
HCI_Command,
HCI_Vendor_Event,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Android Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#chip-capabilities-and-configuration
HCI_LE_GET_VENDOR_CAPABILITIES_COMMAND = hci_vendor_command_op_code(0x153)
HCI_LE_APCF_COMMAND = hci_vendor_command_op_code(0x157)
HCI_GET_CONTROLLER_ACTIVITY_ENERGY_INFO_COMMAND = hci_vendor_command_op_code(0x159)
HCI_A2DP_HARDWARE_OFFLOAD_COMMAND = hci_vendor_command_op_code(0x15D)
HCI_BLUETOOTH_QUALITY_REPORT_COMMAND = hci_vendor_command_op_code(0x15E)
HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
HCI_Command.register_commands(globals())
HCI_Vendor_Event.register_subevents(globals())
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
)
class HCI_LE_Get_Vendor_Capabilities_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = HCI_Object(cls.return_parameters_fields, **nones)
try:
offset = 0
for field in cls.return_parameters_fields:
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(
parameters, offset, field_type
)
setattr(return_parameters, field_name, field_value)
offset += field_size
except struct.error:
pass
return return_parameters
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_LE_APCF_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
APCF_ENABLE = 0x00
APCF_SET_FILTERING_PARAMETERS = 0x01
APCF_BROADCASTER_ADDRESS = 0x02
APCF_SERVICE_UUID = 0x03
APCF_SERVICE_SOLICITATION_UUID = 0x04
APCF_LOCAL_NAME = 0x05
APCF_MANUFACTURER_DATA = 0x06
APCF_SERVICE_DATA = 0x07
APCF_TRANSPORT_DISCOVERY_SERVICE = 0x08
APCF_AD_TYPE_FILTER = 0x09
APCF_READ_EXTENDED_FEATURES = 0xFF
OPCODE_NAMES = {
APCF_ENABLE: 'APCF_ENABLE',
APCF_SET_FILTERING_PARAMETERS: 'APCF_SET_FILTERING_PARAMETERS',
APCF_BROADCASTER_ADDRESS: 'APCF_BROADCASTER_ADDRESS',
APCF_SERVICE_UUID: 'APCF_SERVICE_UUID',
APCF_SERVICE_SOLICITATION_UUID: 'APCF_SERVICE_SOLICITATION_UUID',
APCF_LOCAL_NAME: 'APCF_LOCAL_NAME',
APCF_MANUFACTURER_DATA: 'APCF_MANUFACTURER_DATA',
APCF_SERVICE_DATA: 'APCF_SERVICE_DATA',
APCF_TRANSPORT_DISCOVERY_SERVICE: 'APCF_TRANSPORT_DISCOVERY_SERVICE',
APCF_AD_TYPE_FILTER: 'APCF_AD_TYPE_FILTER',
APCF_READ_EXTENDED_FEATURES: 'APCF_READ_EXTENDED_FEATURES',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
],
)
class HCI_Get_Controller_Activity_Energy_Info_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_A2DP_Hardware_Offload_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
OPCODE_NAMES = {
START_A2DP_OFFLOAD: 'START_A2DP_OFFLOAD',
STOP_A2DP_OFFLOAD: 'STOP_A2DP_OFFLOAD',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
OPCODE_NAMES = {
GET_AUDIO_BUFFER_TIME_CAPABILITY: 'GET_AUDIO_BUFFER_TIME_CAPABILITY',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Vendor_Event.event(
fields=[
('quality_report_id', 1),
('packet_types', 1),
('connection_handle', 2),
('connection_role', {'size': 1, 'mapper': HCI_Constant.role_name}),
('tx_power_level', -1),
('rssi', -1),
('snr', 1),
('unused_afh_channel_count', 1),
('afh_select_unideal_channel_count', 1),
('lsto', 2),
('connection_piconet_clock', 4),
('retransmission_count', 4),
('no_rx_count', 4),
('nak_count', 4),
('last_tx_ack_timestamp', 4),
('flow_off_count', 4),
('last_flow_on_timestamp', 4),
('buffer_overflow_bytes', 4),
('buffer_underflow_bytes', 4),
('bdaddr', Address.parse_address),
('cal_failed_item_count', 1),
('tx_total_packets', 4),
('tx_unacked_packets', 4),
('tx_flushed_packets', 4),
('tx_last_subevent_packets', 4),
('crc_error_packets', 4),
('rx_duplicate_packets', 4),
('vendor_specific_parameters', '*'),
]
)
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
'''

0
bumble/vendor/zephyr/__init__.py vendored Normal file
View File

88
bumble/vendor/zephyr/hci.py vendored Normal file
View File

@@ -0,0 +1,88 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.hci import (
hci_vendor_command_op_code,
HCI_Command,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Zephyr RTOS Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
HCI_WRITE_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000E)
HCI_READ_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000F)
HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
class TX_Power_Level_Command:
'''
Base class for read and write TX power level HCI commands
'''
TX_POWER_HANDLE_TYPE_ADV = 0x00
TX_POWER_HANDLE_TYPE_SCAN = 0x01
TX_POWER_HANDLE_TYPE_CONN = 0x02
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2), ('tx_power_level', -1)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
],
)
class HCI_Write_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
],
)
class HCI_Read_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''

View File

@@ -64,6 +64,7 @@ nav:
- Linux: platforms/linux.md - Linux: platforms/linux.md
- Windows: platforms/windows.md - Windows: platforms/windows.md
- Android: platforms/android.md - Android: platforms/android.md
- Zephyr: platforms/zephyr.md
- Examples: - Examples:
- Overview: examples/index.md - Overview: examples/index.md

Binary file not shown.

View File

@@ -9,3 +9,4 @@ For platform-specific information, see the following pages:
* :material-linux: Linux - see the [Linux platform page](linux.md) * :material-linux: Linux - see the [Linux platform page](linux.md)
* :material-microsoft-windows: Windows - see the [Windows platform page](windows.md) * :material-microsoft-windows: Windows - see the [Windows platform page](windows.md)
* :material-android: Android - see the [Android platform page](android.md) * :material-android: Android - see the [Android platform page](android.md)
* :material-memory: Zephyr - see the [Zephyr platform page](zephyr.md)

View File

@@ -0,0 +1,51 @@
:material-memory: ZEPHYR PLATFORM
=================================
Set TX Power on nRF52840
------------------------
The Nordic nRF52840 supports Zephyr's vendor specific HCI command for setting TX
power during advertising, connection, or scanning. With the example [HCI
USB](https://docs.zephyrproject.org/latest/samples/bluetooth/hci_usb/README.html)
application, an [nRF52840
dongle](https://www.nordicsemi.com/Products/Development-
hardware/nRF52840-Dongle) can be used as a Bumble controller.
To add dynamic TX power support to the HCI USB application, add the following to
`zephyr/samples/bluetooth/hci_usb/prj.conf` and build.
```
CONFIG_BT_CTLR_ADVANCED_FEATURES=y
CONFIG_BT_CTLR_CONN_RSSI=y
CONFIG_BT_CTLR_TX_PWR_DYNAMIC_CONTROL=y
```
Alternatively, a prebuilt firmware application can be downloaded here:
[hci_usb.zip](../downloads/zephyr/hci_usb.zip).
Put the nRF52840 dongle into bootloader mode by pressing the RESET button. The
LED should pulse red. Load the firmware application with the `nrfutil` tool:
```
nrfutil dfu usb-serial -pkg hci_usb.zip -p /dev/ttyACM0
```
The vendor specific HCI commands to read and write TX power are defined in
`bumble/vendor/zephyr/hci.py` and may be used as such:
```python
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB
response = await host.send_command(
HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0,
tx_power_level=-4,
)
)
if response.return_parameters.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}")
```

View File

@@ -16,11 +16,9 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import collections
import sys import sys
import os import os
import logging import logging
from typing import Union
from bumble.colors import color from bumble.colors import color
@@ -32,8 +30,7 @@ from bumble.core import (
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
) )
from bumble import rfcomm from bumble import rfcomm, hfp
from bumble.rfcomm import Client
from bumble.sdp import ( from bumble.sdp import (
Client as SDP_Client, Client as SDP_Client,
DataElement, DataElement,
@@ -47,61 +44,6 @@ from bumble.sdp import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol Support
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class HfpProtocol:
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
dlc.sink = self.feed
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1 :]
if len(line) > 0:
self.on_line(line)
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
self.lines_available.clear()
logger.debug(color(f'<<< {line}', 'green'))
return line
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable-next=too-many-nested-blocks # pylint: disable-next=too-many-nested-blocks
async def list_rfcomm_channels(device, connection): async def list_rfcomm_channels(device, connection):
@@ -241,7 +183,7 @@ async def main():
# Create a client and start it # Create a client and start it
print('@@@ Starting to RFCOMM client...') print('@@@ Starting to RFCOMM client...')
rfcomm_client = Client(device, connection) rfcomm_client = rfcomm.Client(device, connection)
rfcomm_mux = await rfcomm_client.start() rfcomm_mux = await rfcomm_client.start()
print('@@@ Started') print('@@@ Started')
@@ -256,7 +198,7 @@ async def main():
return return
# Protocol loop (just for testing at this point) # Protocol loop (just for testing at this point)
protocol = HfpProtocol(session) protocol = hfp.HfpProtocol(session)
while True: while True:
line = await protocol.next_line() line = await protocol.next_line()

View File

@@ -20,83 +20,109 @@ import sys
import os import os
import logging import logging
from bumble.core import UUID
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.core import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID, UUID
from bumble.rfcomm import Server from bumble.rfcomm import Server
from bumble.sdp import ( from bumble.utils import AsyncRunner
DataElement, from bumble.rfcomm import make_service_sdp_records
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sdp_records(channel): def sdp_records(channel, uuid):
service_record_handle = 0x00010001
return { return {
0x00010001: [ service_record_handle: make_service_sdp_records(
ServiceAttribute( service_record_handle, channel, UUID(uuid)
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, )
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
]
),
]
),
),
]
} }
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_dlc(dlc): def on_rfcomm_session(rfcomm_session, tcp_server):
print('*** DLC connected', dlc) print('*** RFComm session connected', rfcomm_session)
dlc.sink = lambda data: on_rfcomm_data_received(dlc, data) tcp_server.attach_session(rfcomm_session)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_rfcomm_data_received(dlc, data): class TcpServerProtocol(asyncio.Protocol):
print(f'<<< Data received: {data.hex()}') def __init__(self, server):
try: self.server = server
message = data.decode('utf-8')
print(f'<<< Message = {message}')
except Exception:
pass
# Echo everything back def connection_made(self, transport):
dlc.write(data) peer_name = transport.get_extra_info('peer_name')
print(f'<<< TCP Server: connection from {peer_name}')
if self.server:
self.server.tcp_transport = transport
else:
transport.close()
def connection_lost(self, exc):
print('<<< TCP Server: connection lost')
if self.server:
self.server.tcp_transport = None
def data_received(self, data):
print(f'<<< TCP Server: data received: {len(data)} bytes - {data.hex()}')
if self.server:
self.server.tcp_data_received(data)
# -----------------------------------------------------------------------------
class TcpServer:
def __init__(self, port):
self.rfcomm_session = None
self.tcp_transport = None
AsyncRunner.spawn(self.run(port))
def attach_session(self, rfcomm_session):
if self.rfcomm_session:
self.rfcomm_session.sink = None
self.rfcomm_session = rfcomm_session
rfcomm_session.sink = self.rfcomm_data_received
def rfcomm_data_received(self, data):
print(f'<<< RFCOMM Data: {data.hex()}')
if self.tcp_transport:
self.tcp_transport.write(data)
else:
print('!!! no TCP connection, dropping data')
def tcp_data_received(self, data):
if self.rfcomm_session:
self.rfcomm_session.write(data)
else:
print('!!! no RFComm session, dropping data')
async def run(self, port):
print(f'$$$ Starting TCP server on port {port}')
server = await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(self), '127.0.0.1', port
)
async with server:
await server.serve_forever()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 4:
print('Usage: run_rfcomm_server.py <device-config> <transport-spec>') print(
print('example: run_rfcomm_server.py classic2.json usb:04b4:f901') 'Usage: run_rfcomm_server.py <device-config> <transport-spec> '
'<tcp-port> [<uuid>]'
)
print('example: run_rfcomm_server.py classic2.json usb:0 8888')
return return
tcp_port = int(sys.argv[3])
if len(sys.argv) >= 5:
uuid = sys.argv[4]
else:
uuid = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
@@ -105,15 +131,20 @@ async def main():
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True device.classic_enabled = True
# Create and register a server # Create a TCP server
tcp_server = TcpServer(tcp_port)
# Create and register an RFComm server
rfcomm_server = Server(device) rfcomm_server = Server(device)
# Listen for incoming DLC connections # Listen for incoming DLC connections
channel_number = rfcomm_server.listen(on_dlc) channel_number = rfcomm_server.listen(
print(f'### Listening for connection on channel {channel_number}') lambda session: on_rfcomm_session(session, tcp_server)
)
print(f'### Listening for RFComm connections on channel {channel_number}')
# Setup the SDP to advertise this channel # Setup the SDP to advertise this channel
device.sdp_service_records = sdp_records(channel_number) device.sdp_service_records = sdp_records(channel_number, uuid)
# Start the controller # Start the controller
await device.power_on() await device.power_on()

100
tests/hfp_test.py Normal file
View File

@@ -0,0 +1,100 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest
from typing import Tuple
from .test_utils import TwoDevices
from bumble import hfp
from bumble import rfcomm
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def make_hfp_connections(
hf_config: hfp.Configuration,
) -> Tuple[hfp.HfProtocol, hfp.HfpProtocol]:
# Setup devices
devices = TwoDevices()
await devices.setup_connection()
# Setup RFCOMM channel
wait_dlc = asyncio.get_running_loop().create_future()
rfcomm_channel = rfcomm.Server(devices.devices[0]).listen(
lambda dlc: wait_dlc.set_result(dlc)
)
assert devices.connections[0]
assert devices.connections[1]
client_mux = await rfcomm.Client(devices.devices[1], devices.connections[1]).start()
client_dlc = await client_mux.open_dlc(rfcomm_channel)
server_dlc = await wait_dlc
# Setup HFP connection
hf = hfp.HfProtocol(client_dlc, hf_config)
ag = hfp.HfpProtocol(server_dlc)
return hf, ag
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_slc():
hf_config = hfp.Configuration(
supported_hf_features=[], supported_hf_indicators=[], supported_audio_codecs=[]
)
hf, ag = await make_hfp_connections(hf_config)
async def ag_loop():
while line := await ag.next_line():
if line.startswith('AT+BRSF'):
ag.send_response_line('+BRSF: 0')
elif line.startswith('AT+CIND=?'):
ag.send_response_line(
'+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),'
'("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),'
'("callheld",(0-2))'
)
elif line.startswith('AT+CIND?'):
ag.send_response_line('+CIND: 0,0,1,4,1,5,0')
ag.send_response_line('OK')
ag_task = asyncio.create_task(ag_loop())
await hf.initiate_slc()
ag_task.cancel()
# -----------------------------------------------------------------------------
async def run():
await test_slc()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -18,6 +18,8 @@
import asyncio import asyncio
import json import json
import logging import logging
import pathlib
import pytest
import tempfile import tempfile
import os import os
@@ -83,87 +85,95 @@ JSON3 = """
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_basic(): @pytest.fixture
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file: def temporary_file():
keystore = JsonKeyStore('my_namespace', file.name) file = tempfile.NamedTemporaryFile(delete=False)
file.close()
yield file.name
pathlib.Path(file.name).unlink()
# -----------------------------------------------------------------------------
async def test_basic(temporary_file):
with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write("{}") file.write("{}")
file.flush() file.flush()
keys = await keystore.get_all() keystore = JsonKeyStore('my_namespace', temporary_file)
assert len(keys) == 0
keys = PairingKeys() keys = await keystore.get_all()
await keystore.update('foo', keys) assert len(keys) == 0
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is None
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
file.flush() keys = PairingKeys()
with open(file.name, "r", encoding="utf-8") as json_file: await keystore.update('foo', keys)
json_data = json.load(json_file) foo = await keystore.get('foo')
assert 'my_namespace' in json_data assert foo is not None
assert 'foo' in json_data['my_namespace'] assert foo.ltk is None
assert 'ltk' in json_data['my_namespace']['foo'] ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert 'my_namespace' in json_data
assert 'foo' in json_data['my_namespace']
assert 'ltk' in json_data['my_namespace']['foo']
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_parsing(): async def test_parsing(temporary_file):
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write(JSON1) file.write(JSON1)
file.flush() file.flush()
foo = await keystore.get('14:7D:DA:4E:53:A8/P') keystore = JsonKeyStore('my_namespace', file.name)
assert foo is not None foo = await keystore.get('14:7D:DA:4E:53:A8/P')
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683') assert foo is not None
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_default_namespace(): async def test_default_namespace(temporary_file):
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON1) file.write(JSON1)
file.flush() file.flush()
all_keys = await keystore.get_all() keystore = JsonKeyStore(None, file.name)
assert len(all_keys) == 1 all_keys = await keystore.get_all()
name, keys = all_keys[0] assert len(all_keys) == 1
assert name == '14:7D:DA:4E:53:A8/P' name, keys = all_keys[0]
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON2) file.write(JSON2)
file.flush() file.flush()
keys = PairingKeys() keystore = JsonKeyStore(None, file.name)
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) keys = PairingKeys()
keys.ltk = PairingKeys.Key(ltk) ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
await keystore.update('foo', keys) keys.ltk = PairingKeys.Key(ltk)
file.flush() await keystore.update('foo', keys)
with open(file.name, "r", encoding="utf-8") as json_file: with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file) json_data = json.load(json_file)
assert '__DEFAULT__' in json_data assert '__DEFAULT__' in json_data
assert 'foo' in json_data['__DEFAULT__'] assert 'foo' in json_data['__DEFAULT__']
assert 'ltk' in json_data['__DEFAULT__']['foo'] assert 'ltk' in json_data['__DEFAULT__']['foo']
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON3) file.write(JSON3)
file.flush() file.flush()
all_keys = await keystore.get_all() keystore = JsonKeyStore(None, file.name)
assert len(all_keys) == 1 all_keys = await keystore.get_all()
name, keys = all_keys[0] assert len(all_keys) == 1
assert name == '14:7D:DA:4E:53:A8/P' name, keys = all_keys[0]
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------