Compare commits

..

2 Commits

Author SHA1 Message Date
uael 412fd0f78a pandora: implement L2CAP pandora service
Co-authored-by: Josh Wu <joshwu@google.com>
2023-11-07 00:58:33 -08:00
uael ee494a6543 l2cap: refactor server side to allow deferred accept
In order to avoid any breaking changes this re-impl current APIs with
the exact same behavior.

The previous impl was preventing one to defer the response to an l2cap
channel connection request, both for BR/EDR basic channels and LE credit
based ones. This commit change this to spawn a task on every channel
incoming connection request, then all registered listeners are given a
chance to accept it through a `asyncio.Future`. After a bit of delay, if
none had accepted it, the connection is automatically rejected.
2023-11-07 00:43:02 -08:00
32 changed files with 894 additions and 1463 deletions
+3 -5
View File
@@ -56,7 +56,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
@@ -65,17 +65,15 @@ jobs:
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
- name: Rust Tests
run: cd rust && cargo test-all-features
run: cd rust && cargo test
# At some point, hook up publishing the binary. For now, just make sure it builds.
# Once we're ready to publish binaries, this should be built with `--release`.
- name: Build Bumble CLI
+5 -57
View File
@@ -24,16 +24,10 @@ from prompt_toolkit.shortcuts import PromptSession
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.pairing import PairingDelegate, PairingConfig
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import (
AdvertisingData,
ProtocolError,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
)
from bumble.core import ProtocolError
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -66,7 +60,7 @@ class Waiter:
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
io_capability={
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
@@ -292,7 +286,6 @@ async def pair(
bond,
ctkd,
io,
oob,
prompt,
request,
print_keys,
@@ -350,51 +343,16 @@ async def pair(
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
shared_data = (
None
if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
)
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig(
our_context=our_oob_context,
peer_data=shared_data,
legacy_context=legacy_context,
)
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
delegate=Delegate(mode, connection, io, prompt),
sc, mitm, bond, Delegate(mode, connection, io, prompt)
)
# Connect to a peer or wait for a connection
device.on('connection', lambda connection: on_connection(connection, request))
if address_or_name is not None:
print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect(
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
connection = await device.connect(address_or_name)
if not request:
try:
@@ -463,14 +421,6 @@ class LogHandler(logging.Handler):
default='display+keyboard',
show_default=True,
)
@click.option(
'--oob',
metavar='<oob-data-hex>',
help=(
'Use OOB pairing with this data from the peer '
'(use "-" to enable OOB without peer data)'
),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
@@ -491,7 +441,6 @@ def main(
bond,
ctkd,
io,
oob,
prompt,
request,
print_keys,
@@ -515,7 +464,6 @@ def main(
bond,
ctkd,
io,
oob,
prompt,
request,
print_keys,
-3
View File
@@ -1000,9 +1000,6 @@ class Controller:
'''
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
'''
if self.le_scan_enable:
return bytes([HCI_COMMAND_DISALLOWED_ERROR])
self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval
self.le_scan_window = command.le_scan_window
-11
View File
@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import List, Optional, Tuple, Union, cast, Dict
@@ -1052,13 +1051,3 @@ class ConnectionPHY:
def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
# -----------------------------------------------------------------------------
# LE Role
# -----------------------------------------------------------------------------
class LeRole(enum.IntEnum):
PERIPHERAL_ONLY = 0x00
CENTRAL_ONLY = 0x01
BOTH_PERIPHERAL_PREFERRED = 0x02
BOTH_CENTRAL_PREFERRED = 0x03
-121
View File
@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import collections
import enum
import functools
import logging
import struct
@@ -1371,7 +1370,6 @@ HCI_LE_SUPPORTED_FEATURES_NAMES = {
if feature_name.startswith('HCI_') and feature_name.endswith('_LE_SUPPORTED_FEATURE')
}
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
@@ -1927,9 +1925,6 @@ class HCI_Packet:
if packet_type == HCI_ACL_DATA_PACKET:
return HCI_AclDataPacket.from_bytes(packet)
if packet_type == HCI_SYNCHRONOUS_DATA_PACKET:
return HCI_SynchronousDataPacket.from_bytes(packet)
if packet_type == HCI_EVENT_PACKET:
return HCI_Event.from_bytes(packet)
@@ -2298,19 +2293,6 @@ class HCI_Read_Clock_Offset_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('bd_addr', Address.parse_address),
('reason', {'size': 1, 'mapper': HCI_Constant.error_name}),
],
)
class HCI_Reject_Synchronous_Connection_Request_Command(HCI_Command):
'''
See Bluetooth spec @ 7.1.28 Reject Synchronous Connection Request Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
@@ -2472,51 +2454,6 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_Command):
See Bluetooth spec @ 7.1.45 Enhanced Setup Synchronous Connection Command
'''
class CodingFormat(enum.IntEnum):
U_LOG = 0x00
A_LOG = 0x01
CVSD = 0x02
TRANSPARENT = 0x03
PCM = 0x04
MSBC = 0x05
LC3 = 0x06
G729A = 0x07
def to_bytes(self):
return self.value.to_bytes(5, 'little')
def __bytes__(self):
return self.to_bytes()
class PcmDataFormat(enum.IntEnum):
NA = 0x00
ONES_COMPLEMENT = 0x01
TWOS_COMPLEMENT = 0x02
SIGN_MAGNITUDE = 0x03
UNSIGNED = 0x04
class DataPath(enum.IntEnum):
HCI = 0x00
PCM = 0x01
class RetransmissionEffort(enum.IntEnum):
NO_RETRANSMISSION = 0x00
OPTIMIZE_FOR_POWER = 0x01
OPTIMIZE_FOR_QUALITY = 0x02
DONT_CARE = 0xFF
class PacketType(enum.IntFlag):
HV1 = 0x0001
HV2 = 0x0002
HV3 = 0x0004
EV3 = 0x0008
EV4 = 0x0010
EV5 = 0x0020
NO_2_EV3 = 0x0040
NO_3_EV3 = 0x0080
NO_2_EV5 = 0x0100
NO_3_EV5 = 0x0200
# -----------------------------------------------------------------------------
@HCI_Command.command(
@@ -5801,64 +5738,6 @@ class HCI_AclDataPacket(HCI_Packet):
)
# -----------------------------------------------------------------------------
class HCI_SynchronousDataPacket(HCI_Packet):
'''
See Bluetooth spec @ 5.4.3 HCI SCO Data Packets
'''
hci_packet_type = HCI_SYNCHRONOUS_DATA_PACKET
@staticmethod
def from_bytes(packet: bytes) -> HCI_SynchronousDataPacket:
# Read the header
h, data_total_length = struct.unpack_from('<HB', packet, 1)
connection_handle = h & 0xFFF
packet_status = (h >> 12) & 0b11
rfu = (h >> 14) & 0b11
data = packet[4:]
if len(data) != data_total_length:
raise ValueError(
f'invalid packet length {len(data)} != {data_total_length}'
)
return HCI_SynchronousDataPacket(
connection_handle, packet_status, rfu, data_total_length, data
)
def to_bytes(self) -> bytes:
h = (self.packet_status << 12) | (self.rfu << 14) | self.connection_handle
return (
struct.pack('<BHB', HCI_SYNCHRONOUS_DATA_PACKET, h, self.data_total_length)
+ self.data
)
def __init__(
self,
connection_handle: int,
packet_status: int,
rfu: int,
data_total_length: int,
data: bytes,
) -> None:
self.connection_handle = connection_handle
self.packet_status = packet_status
self.rfu = rfu
self.data_total_length = data_total_length
self.data = data
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self) -> str:
return (
f'{color("SCO", "blue")}: '
f'handle=0x{self.connection_handle:04x}, '
f'ps={self.packet_status}, rfu={self.rfu}, '
f'data_total_length={self.data_total_length}, '
f'data={self.data.hex()}'
)
# -----------------------------------------------------------------------------
class HCI_AclDataPacketAssembler:
current_data: Optional[bytes]
-168
View File
@@ -35,7 +35,6 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.hci import HCI_Enhanced_Setup_Synchronous_Connection_Command
from bumble.sdp import (
DataElement,
ServiceAttribute,
@@ -820,170 +819,3 @@ def sdp_records(
DataElement.unsigned_integer_16(hf_supported_features),
),
]
# -----------------------------------------------------------------------------
# ESCO Codec Default Parameters
# -----------------------------------------------------------------------------
# Hands-Free Profile v1.8, 5.7 Codec Interoperability Requirements
class DefaultCodecParameters(enum.IntEnum):
SCO_CVSD_D0 = enum.auto()
SCO_CVSD_D1 = enum.auto()
ESCO_CVSD_S1 = enum.auto()
ESCO_CVSD_S2 = enum.auto()
ESCO_CVSD_S3 = enum.auto()
ESCO_CVSD_S4 = enum.auto()
ESCO_MSBC_T1 = enum.auto()
ESCO_MSBC_T2 = enum.auto()
@dataclasses.dataclass
class EscoParameters:
# Codec specific
transmit_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat
receive_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat
packet_type: HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType
retransmission_effort: HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort
max_latency: int
# Common
input_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT
)
output_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT
)
input_coded_data_size: int = 16
output_coded_data_size: int = 16
input_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
)
output_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
)
input_pcm_sample_payload_msb_position: int = 0
output_pcm_sample_payload_msb_position: int = 0
input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath.HCI
)
output_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath.HCI
)
input_transport_unit_size: int = 0
output_transport_unit_size: int = 0
input_bandwidth: int = 16000
output_bandwidth: int = 16000
transmit_bandwidth: int = 8000
receive_bandwidth: int = 8000
transmit_codec_frame_size: int = 60
receive_codec_frame_size: int = 60
_ESCO_PARAMETERS_CVSD_D0 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0xFFFF,
packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV1,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
)
_ESCO_PARAMETERS_CVSD_D1 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0xFFFF,
packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV3,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
)
_ESCO_PARAMETERS_CVSD_S1 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0x0007,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S2 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0x0007,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S3 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0x000A,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S4 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
max_latency=0x000C,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
_ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
max_latency=0x0008,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
_ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
max_latency=0x000D,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
ESCO_PERAMETERS = {
DefaultCodecParameters.SCO_CVSD_D0: _ESCO_PARAMETERS_CVSD_D0,
DefaultCodecParameters.SCO_CVSD_D1: _ESCO_PARAMETERS_CVSD_D1,
DefaultCodecParameters.ESCO_CVSD_S1: _ESCO_PARAMETERS_CVSD_S1,
DefaultCodecParameters.ESCO_CVSD_S2: _ESCO_PARAMETERS_CVSD_S2,
DefaultCodecParameters.ESCO_CVSD_S3: _ESCO_PARAMETERS_CVSD_S3,
DefaultCodecParameters.ESCO_CVSD_S4: _ESCO_PARAMETERS_CVSD_S4,
DefaultCodecParameters.ESCO_MSBC_T1: _ESCO_PARAMETERS_MSBC_T1,
DefaultCodecParameters.ESCO_MSBC_T2: _ESCO_PARAMETERS_MSBC_T2,
}
+8 -34
View File
@@ -21,7 +21,7 @@ import collections
import logging
import struct
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
@@ -43,7 +43,6 @@ from .hci import (
HCI_RESET_COMMAND,
HCI_SUCCESS,
HCI_SUPPORTED_COMMANDS_FLAGS,
HCI_SYNCHRONOUS_DATA_PACKET,
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
@@ -68,7 +67,6 @@ from .hci import (
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
HCI_SynchronousDataPacket,
)
from .core import (
BT_BR_EDR_TRANSPORT,
@@ -487,14 +485,12 @@ class Host(AbortableEventEmitter):
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:
self.on_hci_command_packet(cast(HCI_Command, packet))
elif packet.hci_packet_type == HCI_EVENT_PACKET:
self.on_hci_event_packet(cast(HCI_Event, packet))
elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
self.on_hci_acl_data_packet(cast(HCI_AclDataPacket, packet))
elif packet.hci_packet_type == HCI_SYNCHRONOUS_DATA_PACKET:
self.on_hci_sco_data_packet(cast(HCI_SynchronousDataPacket, packet))
if isinstance(packet, HCI_Command):
self.on_hci_command_packet(packet)
elif isinstance(packet, HCI_Event):
self.on_hci_event_packet(packet)
elif isinstance(packet, HCI_AclDataPacket):
self.on_hci_acl_data_packet(packet)
else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
@@ -511,10 +507,6 @@ class Host(AbortableEventEmitter):
if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet)
def on_hci_sco_data_packet(self, packet: HCI_SynchronousDataPacket) -> None:
# Experimental
self.emit('sco_packet', packet.connection_handle, packet)
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
self.emit('l2cap_pdu', connection.handle, cid, pdu)
@@ -768,25 +760,7 @@ class Host(AbortableEventEmitter):
asyncio.create_task(send_long_term_key())
def on_hci_synchronous_connection_complete_event(self, event):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### SCO CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
# Notify the client
self.emit(
'sco_connection',
event.bd_addr,
event.connection_handle,
event.link_type,
)
else:
logger.debug(f'### SCO CONNECTION FAILED: {event.status}')
# Notify the client
self.emit('sco_connection_failure', event.bd_addr, event.status)
pass
def on_hci_synchronous_connection_changed_event(self, event):
pass
+389 -209
View File
@@ -35,8 +35,10 @@ from typing import (
Union,
Deque,
Iterable,
Set,
SupportsBytes,
TYPE_CHECKING,
overload,
)
from .utils import deprecated
@@ -237,6 +239,8 @@ class L2CAP_Control_Frame:
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name: str
identifier: int
pdu: bytes
@staticmethod
def from_bytes(pdu: bytes) -> L2CAP_Control_Frame:
@@ -391,6 +395,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
psm: int
source_cid: int
@staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
@@ -637,7 +644,11 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
(CODE 0x14)
'''
le_psm: int
source_cid: int
mtu: int
mps: int
initial_credits: int
# -----------------------------------------------------------------------------
@@ -1375,19 +1386,14 @@ class LeCreditBasedChannel(EventEmitter):
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[ClassicChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
@@ -1395,28 +1401,18 @@ class ClassicChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
_close_closure: Callable[[], None]
psm: int
handler: Optional[Callable[[LeCreditBasedChannel], Any]]
def __post_init__(self) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
@@ -1424,21 +1420,107 @@ class LeCreditBasedChannelServer(EventEmitter):
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
self._close_closure()
# -----------------------------------------------------------------------------
class PendingConnection:
"""
All pending connection types.
A `PendingConnection` is a temporary object used to accept an incoming connection
request, it contains the acceptor channel configuration preferences and transition
to the connected state through the `on_connection` callback.
This object is not supposed to live anymore once the channel is connected.
"""
class Any:
"""L2CAP any channel pending connection."""
on_connection: Callable[[Any], None]
mtu: int
@dataclasses.dataclass
class Basic(Any):
"""L2CAP basic channel pending connection."""
on_connection: Callable[[ClassicChannel], None] = lambda _: None
mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP LE credit based channel pending connection."""
on_connection: Callable[[LeCreditBasedChannel], None] = lambda _: None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
class IncomingConnection:
"""
All incoming connection types.
A `IncomingConnection` is a temporary object used to notify listeners of an
incoming channel connection request. It can accepted through the `future` field.
Multiple listeners can observe the same incoming connection request, but no more
than one can actually accept, first come first served. Thus it's recommended for
delayed accept to before check the state of the future field.
This object is not supposed to live anymore once accepted.
Example:
```python
fut = asyncio.Future()
def listener(incoming: IncomingConnection.Any) -> None:
if isinstance(incoming, IncomingConnection.Basic) and incoming.psm == 0xcafe:
incoming.future.set_result(PendingConnection.Basic(fut.set_result, mtu=123))
device.l2cap_manager.listen(listener)
channel = await fut
```
"""
@dataclasses.dataclass
class Any:
"""L2CAP any incoming channel connection request."""
connection: Connection
psm: int
source_cid: int
def __post_init__(self) -> None:
self.future: asyncio.Future[Any] = asyncio.Future()
@dataclasses.dataclass
class Basic(Any):
"""L2CAP incoming basic channel connection request."""
future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False)
@dataclasses.dataclass
class LeCreditBased(Any):
"""L2CAP incoming LE credit based channel connection request."""
mtu: int
mps: int
initial_credits: int
future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field(
init=False
)
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
connection_parameters_update_response: Optional[asyncio.Future[int]]
listeners: List[Callable[[IncomingConnection.Any], None]]
used_psm: Set[int]
def __init__(
self,
@@ -1452,15 +1534,15 @@ class ChannelManager:
L2CAP_SIGNALING_CID: None,
L2CAP_LE_SIGNALING_CID: None,
}
self.servers = {} # Servers accepting connections, by PSM
self.le_coc_channels = (
{}
) # LE CoC channels, mapped by connection and destination cid
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.extended_features = extended_features
self.connectionless_mtu = connectionless_mtu
self.connection_parameters_update_response = None
self.listeners = []
self.used_psm = set()
@property
def host(self) -> Host:
@@ -1513,6 +1595,31 @@ class ChannelManager:
raise RuntimeError('no free CID')
def allocate_psm(self) -> int:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def allocate_spsm(self) -> int:
# Find a free sPSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.used_psm:
continue
return candidate
raise InvalidStateError('no free PSM')
def free_psm(self, psm: int) -> None:
self.used_psm.remove(psm)
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -1527,6 +1634,35 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@overload
def listen(
self, cb: Callable[[IncomingConnection.Basic], None]
) -> Callable[[IncomingConnection.Basic], None]:
...
@overload
def listen(
self, cb: Callable[[IncomingConnection.LeCreditBased], None]
) -> Callable[[IncomingConnection.LeCreditBased], None]:
...
def listen(self, cb: Any) -> Any:
if cb in self.listeners:
raise ValueError('listener already registered')
self.listeners.append(cb)
return cb
@overload
def unlisten(self, cb: Callable[[IncomingConnection.Basic], None]) -> None:
...
@overload
def unlisten(self, cb: Callable[[IncomingConnection.LeCreditBased], None]) -> None:
...
def unlisten(self, cb: Any) -> None:
self.listeners.remove(cb)
@deprecated("Please use create_classic_server")
def register_server(
self,
@@ -1534,7 +1670,7 @@ class ChannelManager:
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
handler=server, spec=ClassicChannelSpec(psm=None if psm == 0 else psm)
).psm
def create_classic_server(
@@ -1542,24 +1678,12 @@ class ChannelManager:
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: ClassicChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
@@ -1568,10 +1692,22 @@ class ChannelManager:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.used_psm.add(spec.psm)
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
def listener(incoming: IncomingConnection.Basic) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.Basic(server.on_connection, spec.mtu)
)
return self.servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = ClassicChannelServer(close, spec.psm, handler)
return server
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
@@ -1594,32 +1730,30 @@ class ChannelManager:
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if not spec.psm:
# Find a free PSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.le_coc_servers:
continue
spec.psm = candidate
break
else:
raise InvalidStateError('no free PSM')
server: LeCreditBasedChannelServer
if spec.psm is None:
spec.psm = self.allocate_psm()
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use')
if spec.psm is self.used_psm:
raise ValueError(f'{spec.psm}: SPSM already in use')
self.used_psm.add(spec.psm)
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
spec.psm,
handler,
max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
)
def listener(incoming: IncomingConnection.LeCreditBased) -> None:
if incoming.psm == spec.psm:
incoming.future.set_result(
PendingConnection.LeCreditBased(
server.on_connection, spec.mtu, spec.mps, spec.max_credits
)
)
return self.le_coc_servers[spec.psm]
def close() -> None:
self.unlisten(listener)
assert spec.psm is not None
self.free_psm(spec.psm)
self.listen(listener)
server = LeCreditBasedChannelServer(close, spec.psm, handler)
return server
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1719,15 +1853,62 @@ class ChannelManager:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(
self, connection: Connection, cid: int, request
self, connection: Connection, cid: int, request: L2CAP_Connection_Request
) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.Basic(
connection, request.psm, request.source_cid
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, pending.mtu
)
connection_channels[source_cid] = channel
# Notify
pending.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
@@ -1736,41 +1917,13 @@ class ChannelManager:
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
return
# Create a new channel
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.mtu
)
connection_channels[source_cid] = channel
# Notify
server.on_connection(channel)
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_connection_response(
self, connection: Connection, cid: int, response
@@ -1971,108 +2124,135 @@ class ChannelManager:
)
def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request
self,
connection: Connection,
cid: int,
request: L2CAP_LE_Credit_Based_Connection_Request,
) -> None:
if request.le_psm in self.le_coc_servers:
server = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
# Asynchronous connection request handling.
async def handle_connection_request() -> None:
incoming = IncomingConnection.LeCreditBased(
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
request.initial_credits,
)
# Notify
server.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Dispatch incoming connection.
for listener in self.listeners:
if not incoming.future.done():
listener(incoming)
try:
pending = await asyncio.wait_for(incoming.future, timeout=3.0)
except asyncio.TimeoutError as e:
incoming.future.cancel(e)
pending = None
if pending:
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
pending.mtu,
pending.mps,
request.initial_credits,
request.mtu,
request.mps,
pending.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=pending.mtu,
mps=pending.mps,
initial_credits=pending.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
# Notify
pending.on_connection(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
# Spawn connection request handling.
connection.abort_on('disconnection', handle_connection_request())
def on_l2cap_le_credit_based_connection_response(
self, connection: Connection, _cid: int, response
+1 -67
View File
@@ -15,9 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Optional, Tuple
from .hci import (
@@ -37,60 +35,7 @@ from .smp import (
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
)
from .core import AdvertisingData, LeRole
# -----------------------------------------------------------------------------
@dataclass
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
shared_data_c = ad_data
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
shared_data_r = ad_data
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
instance.legacy_context = OobLegacyContext(tk=ad_data)
if shared_data_c and shared_data_r:
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
return instance
def to_ad(self) -> AdvertisingData:
ad_structures = []
if self.address is not None:
ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
)
if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None:
ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
)
return AdvertisingData(ad_structures)
# -----------------------------------------------------------------------------
@@ -228,14 +173,6 @@ class PairingConfig:
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__(
self,
sc: bool = True,
@@ -243,20 +180,17 @@ class PairingConfig:
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
self.oob = oob
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}]), '
f'oob[{self.oob}])'
f'delegate[{self.delegate.io_capability}])'
)
+3
View File
@@ -26,11 +26,13 @@ from .config import Config
from .device import PandoraDevice
from .host import HostService
from .security import SecurityService, SecurityStorageService
from .l2cap import L2CAPService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from typing import Callable, List, Optional
# public symbols
@@ -77,6 +79,7 @@ async def serve(
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
# call hooks if any.
for hook in _SERVICERS_HOOKS:
+289
View File
@@ -0,0 +1,289 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import grpc
import struct
from bumble import device
from bumble import l2cap
from bumble.pandora import config
from bumble.pandora import utils
from bumble.utils import EventWatcher
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora import l2cap_pb2
from pandora import l2cap_grpc_aio
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Union
@dataclasses.dataclass
class ChannelProxy:
channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel, None]
def __post_init__(self) -> None:
assert self.channel
self.rx: asyncio.Queue[bytes] = asyncio.Queue()
self._disconnection_result: asyncio.Future[None] = asyncio.Future()
self.channel.sink = self.rx.put_nowait
def on_close() -> None:
assert not self._disconnection_result.done()
self.channel = None
self._disconnection_result.set_result(None)
self.channel.on('close', on_close)
def send(self, data: bytes) -> None:
assert self.channel
if isinstance(self.channel, l2cap.ClassicChannel):
self.channel.send_pdu(data)
else:
self.channel.write(data)
async def disconnect(self) -> None:
assert self.channel
await self.channel.disconnect()
async def wait_disconnect(self) -> None:
await self._disconnection_result
assert not self.channel
@dataclasses.dataclass
class ChannelIndex:
connection_handle: int
cid: int
@classmethod
def from_token(cls, token: l2cap_pb2.Channel) -> 'ChannelIndex':
connection_handle, cid = struct.unpack('>HH', token.cookie.value)
return cls(connection_handle, cid)
def into_token(self) -> l2cap_pb2.Channel:
return l2cap_pb2.Channel(
cookie=any_pb2.Any(
value=struct.pack('>HH', self.connection_handle, self.cid)
)
)
def __hash__(self):
return hash(self.connection_handle | (self.cid << 12))
class L2CAPService(l2cap_grpc_aio.L2CAPServicer):
channels: Dict[ChannelIndex, ChannelProxy] = {}
pending: List[l2cap.IncomingConnection.Any] = []
accepts: List[asyncio.Queue[l2cap.IncomingConnection.Any]] = []
def __init__(self, dev: device.Device, config: config.Config) -> None:
self.device = dev
self.config = config
def on_connection(incoming: l2cap.IncomingConnection.Any) -> None:
self.pending.append(incoming)
for acceptor in self.accepts:
acceptor.put_nowait(incoming)
# Make sure our listener is called before the builtins ones.
self.device.l2cap_channel_manager.listeners.insert(0, on_connection)
def register(self, index: ChannelIndex, proxy: ChannelProxy) -> None:
self.channels[index] = proxy
def on_close(*_: Any) -> None:
# TODO: Fix Bumble L2CAP which emit `close` event twice.
if index in self.channels:
del self.channels[index]
# Listen for disconnection.
assert proxy.channel
proxy.channel.on('close', on_close)
async def listen(self) -> AsyncIterator[l2cap.IncomingConnection.Any]:
for incoming in self.pending:
if incoming.future.done():
self.pending.remove(incoming)
continue
yield incoming
queue: asyncio.Queue[l2cap.IncomingConnection.Any] = asyncio.Queue()
self.accepts.append(queue)
try:
while incoming := await queue.get():
yield incoming
finally:
self.accepts.remove(queue)
@utils.rpc
async def Connect(
self, request: l2cap_pb2.ConnectRequest, context: grpc.ServicerContext
) -> l2cap_pb2.ConnectResponse:
# Retrieve Bumble `Connection` from request.
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if connection is None:
raise RuntimeError(f'{connection_handle}: not connection for handle')
channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]
if request.type_variant() == 'basic':
assert request.basic
channel = await connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(
psm=request.basic.psm, mtu=request.basic.mtu
)
)
elif request.type_variant() == 'le_credit_based':
assert request.le_credit_based
channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
psm=request.le_credit_based.spsm,
max_credits=request.le_credit_based.initial_credit,
mtu=request.le_credit_based.mtu,
mps=request.le_credit_based.mps,
)
)
else:
raise NotImplementedError(f"{request.type_variant()}: unsupported type")
index = ChannelIndex(channel.connection.handle, channel.source_cid)
self.register(index, ChannelProxy(channel))
return l2cap_pb2.ConnectResponse(channel=index.into_token())
@utils.rpc
async def WaitConnection(
self, request: l2cap_pb2.WaitConnectionRequest, context: grpc.ServicerContext
) -> l2cap_pb2.WaitConnectionResponse:
iter = self.listen()
fut: asyncio.Future[
Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]
] = asyncio.Future()
# Filter by connection.
if request.connection:
handle = int.from_bytes(request.connection.cookie.value, 'big')
iter = (it async for it in iter if it.connection.handle == handle)
if request.type_variant() == 'basic':
assert request.basic
basic = l2cap.PendingConnection.Basic(
fut.set_result,
request.basic.mtu or l2cap.L2CAP_MIN_BR_EDR_MTU,
)
async for i in (
it
async for it in iter
if isinstance(it, l2cap.IncomingConnection.Basic)
):
if not i.future.done() and i.psm == request.basic.psm:
i.future.set_result(basic)
break
elif request.type_variant() == 'le_credit_based':
assert request.le_credit_based
le_credit_based = l2cap.PendingConnection.LeCreditBased(
fut.set_result,
request.le_credit_based.mtu
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
request.le_credit_based.mps
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
request.le_credit_based.initial_credit
or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
)
async for j in (
it
async for it in iter
if isinstance(it, l2cap.IncomingConnection.LeCreditBased)
):
if not j.future.done() and j.psm == request.le_credit_based.spsm:
j.future.set_result(le_credit_based)
break
else:
raise NotImplementedError(f"{request.type_variant()}: unsupported type")
channel = await fut
index = ChannelIndex(channel.connection.handle, channel.source_cid)
self.register(index, ChannelProxy(channel))
return l2cap_pb2.WaitConnectionResponse(channel=index.into_token())
@utils.rpc
async def Disconnect(
self, request: l2cap_pb2.DisconnectRequest, context: grpc.ServicerContext
) -> l2cap_pb2.DisconnectResponse:
channel = self.channels[ChannelIndex.from_token(request.channel)]
await channel.disconnect()
return l2cap_pb2.DisconnectResponse(success=empty_pb2.Empty())
@utils.rpc
async def WaitDisconnection(
self, request: l2cap_pb2.WaitDisconnectionRequest, context: grpc.ServicerContext
) -> l2cap_pb2.WaitDisconnectionResponse:
channel = self.channels[ChannelIndex.from_token(request.channel)]
await channel.wait_disconnect()
return l2cap_pb2.WaitDisconnectionResponse(success=empty_pb2.Empty())
@utils.rpc
async def Receive(
self, request: l2cap_pb2.ReceiveRequest, context: grpc.ServicerContext
) -> AsyncGenerator[l2cap_pb2.ReceiveResponse, None]:
watcher = EventWatcher()
if request.source_variant() == 'channel':
assert request.channel
channel = self.channels[ChannelIndex.from_token(request.channel)]
rx = channel.rx
elif request.source_variant() == 'fixed_channel':
assert request.fixed_channel
rx = asyncio.Queue()
handle = request.fixed_channel.connection is not None and int.from_bytes(
request.fixed_channel.connection.cookie.value, 'big'
)
@watcher.on(self.device.host, 'l2cap_pdu')
def _(connection: device.Connection, cid: int, pdu: bytes) -> None:
assert request.fixed_channel
if cid == request.fixed_channel.cid and (
handle is None or handle == connection.handle
):
rx.put_nowait(pdu)
else:
raise NotImplementedError(f"{request.source_variant()}: unsupported type")
try:
while data := await rx.get():
yield l2cap_pb2.ReceiveResponse(data=data)
finally:
watcher.close()
@utils.rpc
async def Send(
self, request: l2cap_pb2.SendRequest, context: grpc.ServicerContext
) -> l2cap_pb2.SendResponse:
if request.sink_variant() == 'channel':
assert request.channel
channel = self.channels[ChannelIndex.from_token(request.channel)]
channel.send(request.data)
elif request.sink_variant() == 'fixed_channel':
assert request.fixed_channel
# Retrieve Bumble `Connection` from request.
connection_handle = int.from_bytes(
request.fixed_channel.connection.cookie.value, 'big'
)
connection = self.device.lookup_connection(connection_handle)
if connection is None:
raise RuntimeError(f'{connection_handle}: not connection for handle')
self.device.l2cap_channel_manager.send_pdu(
connection, request.fixed_channel.cid, request.data
)
else:
raise NotImplementedError(f"{request.sink_variant()}: unsupported type")
return l2cap_pb2.SendResponse(success=empty_pb2.Empty())
+33 -156
View File
@@ -27,7 +27,6 @@ import logging
import asyncio
import enum
import secrets
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -54,7 +53,6 @@ from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData,
ProtocolError,
name_or_number,
)
@@ -565,54 +563,6 @@ class PairingMethod(enum.IntEnum):
CTKD_OVER_CLASSIC = 4
# -----------------------------------------------------------------------------
class OobContext:
"""Cryptographic context for LE SC OOB pairing."""
ecc_key: crypto.EccKey
r: bytes
def __init__(
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
def share(self) -> OobSharedData:
pkx = bytes(reversed(self.ecc_key.x))
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
# -----------------------------------------------------------------------------
class OobLegacyContext:
"""Cryptographic context for LE Legacy OOB pairing."""
tk: bytes
def __init__(self, tk: Optional[bytes] = None) -> None:
self.tk = crypto.r() if tk is None else tk
# -----------------------------------------------------------------------------
@dataclass
class OobSharedData:
"""Shareable data for LE SC OOB pairing."""
c: bytes
r: bytes
def to_ad(self) -> AdvertisingData:
return AdvertisingData(
[
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
]
)
def __str__(self) -> str:
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
# -----------------------------------------------------------------------------
class Session:
# I/O Capability to pairing method decision matrix
@@ -690,6 +640,8 @@ class Session:
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
self.r = bytes(16)
self.stk = None
self.ltk = None
self.ltk_ediv = 0
@@ -707,7 +659,7 @@ class Session:
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = b''
self.dh_key = None
self.confirm_value = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
@@ -760,8 +712,8 @@ class Session:
self.io_capability = pairing_config.delegate.io_capability
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# OOB (not supported yet)
self.oob = False
# Set up addresses
self_address = connection.self_address
@@ -777,37 +729,9 @@ class Session:
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
# Select the ECC key, TK and r initial value
if pairing_config.oob:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise ValueError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
self.ecc_key = pairing_config.oob.our_context.ecc_key
if pairing_config.oob.legacy_context is None:
self.tk = None
else:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise ValueError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = pairing_config.oob.legacy_context.tk
else:
self.peer_oob_data = None
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = bytes(16)
@property
def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.ecc_key.x)), self.peer_public_key_x)
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property
def pka(self) -> bytes:
@@ -844,10 +768,7 @@ class Session:
return None
def decide_pairing_method(
self,
auth_req: int,
initiator_io_capability: int,
responder_io_capability: int,
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
@@ -988,7 +909,7 @@ class Session:
command = SMP_Pairing_Request_Command(
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
oob_data_flag=0,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -1000,7 +921,7 @@ class Session:
def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command(
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
oob_data_flag=0,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -1061,8 +982,8 @@ class Session:
def send_public_key_command(self) -> None:
self.send_command(
SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.ecc_key.x)),
public_key_y=bytes(reversed(self.ecc_key.y)),
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
)
)
@@ -1109,6 +1030,7 @@ class Session:
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1374,7 +1296,7 @@ class Session:
try:
handler(command)
except Exception as error:
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
@@ -1411,28 +1333,15 @@ class Session:
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req,
command.io_capability,
self.io_capability,
)
# Check for OOB
if command.oob_data_flag != 0:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, command.io_capability, self.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
@@ -1481,26 +1390,15 @@ class Session:
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
# Check for OOB
if self.sc and command.oob_data_flag:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
@@ -1651,13 +1549,12 @@ class Session:
if self.passkey_step < 20:
self.send_pairing_confirm_command()
return
elif self.pairing_method != PairingMethod.OOB:
else:
return
else:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
self.send_pairing_random_command()
elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1694,7 +1591,6 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
@@ -1703,6 +1599,7 @@ class Session:
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
else:
# OOB not implemented yet
return
assert self.preq and self.pres
@@ -1756,7 +1653,7 @@ class Session:
# Compute the DH key
self.dh_key = bytes(
reversed(
self.ecc_key.dh(
self.manager.ecc_key.dh(
bytes(reversed(command.public_key_x)),
bytes(reversed(command.public_key_y)),
)
@@ -1764,27 +1661,8 @@ class Session:
)
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.pairing_method == PairingMethod.OOB:
# Check against shared OOB data
if self.peer_oob_data:
confirm_verifier = crypto.f4(
self.peer_public_key_x,
self.peer_public_key_x,
self.peer_oob_data.r,
bytes(1),
)
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
):
return
if self.is_initiator:
if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command()
else:
self.send_pairing_confirm_command()
self.send_pairing_confirm_command()
else:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
@@ -1795,7 +1673,6 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
# We can now send the confirmation value
self.send_pairing_confirm_command()
-8
View File
@@ -31,7 +31,6 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT,
)
from bumble import rfcomm, hfp
from bumble.hci import HCI_SynchronousDataPacket
from bumble.sdp import (
Client as SDP_Client,
DataElement,
@@ -198,13 +197,6 @@ async def main():
print('@@@ Disconnected from RFCOMM server')
return
def on_sco(connection_handle: int, packet: HCI_SynchronousDataPacket):
# Reset packet and loopback
packet.packet_status = 0
device.host.send_hci_packet(packet)
device.host.on('sco_packet', on_sco)
# Protocol loop (just for testing at this point)
protocol = hfp.HfpProtocol(session)
while True:
+2 -19
View File
@@ -12,13 +12,6 @@ keywords = ["bluetooth", "ble"]
categories = ["api-bindings", "network-programming"]
rust-version = "1.70.0"
# https://github.com/frewsxcv/cargo-all-features#options
[package.metadata.cargo-all-features]
# We are interested in testing subset combinations of this feature, so this is redundant
denylist = ["unstable"]
# To exercise combinations of any of these features, remove from `always_include_features`
always_include_features = ["anyhow", "pyo3-asyncio-attributes", "dev-tools", "bumble-tools"]
[dependencies]
pyo3 = { version = "0.18.3", features = ["macros"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime"] }
@@ -33,7 +26,6 @@ thiserror = "1.0.41"
bytes = "1.5.0"
pdl-derive = "0.2.0"
pdl-runtime = "0.2.0"
futures = "0.3.28"
# Dev tools
file-header = { version = "0.1.2", optional = true }
@@ -44,6 +36,7 @@ anyhow = { version = "1.0.71", optional = true }
clap = { version = "4.3.3", features = ["derive"], optional = true }
directories = { version = "5.0.1", optional = true }
env_logger = { version = "0.10.0", optional = true }
futures = { version = "0.3.28", optional = true }
log = { version = "0.4.19", optional = true }
owo-colors = { version = "3.5.0", optional = true }
reqwest = { version = "0.11.20", features = ["blocking"], optional = true }
@@ -81,11 +74,6 @@ name = "bumble"
path = "src/main.rs"
required-features = ["bumble-tools"]
[[example]]
name = "broadcast"
path = "examples/broadcast.rs"
required-features = ["unstable_extended_adv"]
# test entry point that uses pyo3_asyncio's test harness
[[test]]
name = "pytests"
@@ -97,10 +85,5 @@ anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
# separate feature for CLI so that dependencies don't spend time building these
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger"]
# all the unstable features
unstable = ["unstable_extended_adv"]
unstable_extended_adv = []
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
default = []
+6 -3
View File
@@ -33,7 +33,6 @@
use bumble::wrapper::{
device::{Device, Peer},
hci::{packets::AddressType, Address},
profile::BatteryServiceProxy,
transport::Transport,
PyObjectExt,
@@ -53,8 +52,12 @@ async fn main() -> PyResult<()> {
let transport = Transport::open(cli.transport).await?;
let address = Address::new("F0:F1:F2:F3:F4:F5", AddressType::RandomDeviceAddress)?;
let device = Device::with_hci("Bumble", address, transport.source()?, transport.sink()?)?;
let device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
device.power_on().await?;
+5 -21
View File
@@ -63,28 +63,17 @@ async fn main() -> PyResult<()> {
)
.map_err(|e| anyhow!(e))?;
device.set_advertising_data(adv_data)?;
device.power_on().await?;
if cli.extended {
println!("Starting extended advertisement...");
device.start_advertising_extended(adv_data).await?;
} else {
device.set_advertising_data(adv_data)?;
println!("Starting legacy advertisement...");
device.start_advertising(true).await?;
}
println!("Advertising...");
device.start_advertising(true).await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
if cli.extended {
println!("Stopping extended advertisement...");
device.stop_advertising_extended().await?;
} else {
println!("Stopping legacy advertisement...");
device.stop_advertising().await?;
}
println!("Stopping...");
device.stop_advertising().await?;
Ok(())
}
@@ -97,17 +86,12 @@ struct Cli {
/// See, for instance, `examples/device1.json` in the Python project.
#[arg(long)]
device_config: path::PathBuf,
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Whether to perform an extended (BT 5.0) advertisement
#[arg(long)]
extended: bool,
/// Log HCI commands
#[arg(long)]
log_hci: bool,
+7 -5
View File
@@ -20,9 +20,7 @@
use bumble::{
adv::CommonDataType,
wrapper::{
core::AdvertisementDataUnit,
device::Device,
hci::{packets::AddressType, Address},
core::AdvertisementDataUnit, device::Device, hci::packets::AddressType,
transport::Transport,
},
};
@@ -46,8 +44,12 @@ async fn main() -> PyResult<()> {
let transport = Transport::open(cli.transport).await?;
let address = Address::new("F0:F1:F2:F3:F4:F5", AddressType::RandomDeviceAddress)?;
let mut device = Device::with_hci("Bumble", address, transport.source()?, transport.sink()?)?;
let mut device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
// in practice, devices can send multiple advertisements from the same address, so we keep
// track of a timestamp for each set of data
+77
View File
@@ -0,0 +1,77 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::{
controller::Controller,
device::Device,
drivers::rtk::DriverInfo,
hci::{
packets::{
AddressType, ErrorCode, ReadLocalVersionInformationBuilder,
ReadLocalVersionInformationComplete,
},
Address, Error,
},
host::Host,
link::Link,
transport::Transport,
};
use nix::sys::stat::Mode;
use pyo3::{
exceptions::PyException,
{PyErr, PyResult},
};
#[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> {
let dir = tempfile::tempdir().unwrap();
let mut fifo = dir.path().to_path_buf();
fifo.push("bumble-transport-fifo");
nix::unistd::mkfifo(&fifo, Mode::S_IRWXU).unwrap();
let mut t = Transport::open(format!("file:{}", fifo.to_str().unwrap())).await?;
t.close().await?;
Ok(())
}
#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len());
Ok(())
}
#[pyo3_asyncio::tokio::test]
async fn hci_command_wrapper_has_correct_methods() -> PyResult<()> {
let address = Address::new("F0:F1:F2:F3:F4:F5", &AddressType::RandomDeviceAddress)?;
let link = Link::new_local_link()?;
let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?;
let host = Host::new(controller.clone().into(), controller.into()).await?;
let device = Device::new(None, Some(address), None, Some(host), None)?;
device.power_on().await?;
// Send some simple command. A successful response means [HciCommandWrapper] has the minimum
// required interface for the Python code to think its an [HCI_Command] object.
let command = ReadLocalVersionInformationBuilder {};
let event: ReadLocalVersionInformationComplete = device
.send_command(&command.into(), true)
.await?
.try_into()
.map_err(|e: Error| PyErr::new::<PyException, _>(e.to_string()))?;
assert_eq!(ErrorCode::Success, event.get_status());
Ok(())
}
-22
View File
@@ -1,22 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::drivers::rtk::DriverInfo;
use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len());
Ok(())
}
-86
View File
@@ -1,86 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::{
controller::Controller,
device::Device,
hci::{
packets::{
AddressType, Enable, ErrorCode, LeScanType, LeScanningFilterPolicy,
LeSetScanEnableBuilder, LeSetScanEnableComplete, LeSetScanParametersBuilder,
LeSetScanParametersComplete, OwnAddressType,
},
Address, Error,
},
host::Host,
link::Link,
};
use pyo3::{
exceptions::PyException,
{PyErr, PyResult},
};
#[pyo3_asyncio::tokio::test]
async fn test_hci_roundtrip_success_and_failure() -> PyResult<()> {
let address = Address::new("F0:F1:F2:F3:F4:F5", AddressType::RandomDeviceAddress)?;
let device = create_local_device(address).await?;
device.power_on().await?;
// BLE Spec Core v5.3
// 7.8.9 LE Set Scan Parameters command
// ...
// The Host shall not issue this command when scanning is enabled in the
// Controller; if it is the Command Disallowed error code shall be used.
// ...
let command = LeSetScanEnableBuilder {
filter_duplicates: Enable::Disabled,
// will cause failure later
le_scan_enable: Enable::Enabled,
};
let event: LeSetScanEnableComplete = device
.send_command(command.into(), false)
.await?
.try_into()
.map_err(|e: Error| PyErr::new::<PyException, _>(e.to_string()))?;
assert_eq!(ErrorCode::Success, event.get_status());
let command = LeSetScanParametersBuilder {
le_scan_type: LeScanType::Passive,
le_scan_interval: 0,
le_scan_window: 0,
own_address_type: OwnAddressType::RandomDeviceAddress,
scanning_filter_policy: LeScanningFilterPolicy::AcceptAll,
};
let event: LeSetScanParametersComplete = device
.send_command(command.into(), false)
.await?
.try_into()
.map_err(|e: Error| PyErr::new::<PyException, _>(e.to_string()))?;
assert_eq!(ErrorCode::CommandDisallowed, event.get_status());
Ok(())
}
async fn create_local_device(address: Address) -> PyResult<Device> {
let link = Link::new_local_link()?;
let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?;
let host = Host::new(controller.clone().into(), controller.into()).await?;
Device::new(None, Some(address), None, Some(host), None)
}
-17
View File
@@ -1,17 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod drivers;
mod hci;
mod transport;
-31
View File
@@ -1,31 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::transport::Transport;
use nix::sys::stat::Mode;
use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> {
let dir = tempfile::tempdir().unwrap();
let mut fifo = dir.path().to_path_buf();
fifo.push("bumble-transport-fifo");
nix::unistd::mkfifo(&fifo, Mode::S_IRWXU).unwrap();
let mut t = Transport::open(format!("file:{}", fifo.to_str().unwrap())).await?;
t.close().await?;
Ok(())
}
+5 -7
View File
@@ -94,7 +94,7 @@ impl From<Error> for PacketTypeParseError {
impl WithPacketType<Self> for Command {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Command, self)
prepend_packet_type(PacketType::Command, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
@@ -104,7 +104,7 @@ impl WithPacketType<Self> for Command {
impl WithPacketType<Self> for Acl {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Acl, self)
prepend_packet_type(PacketType::Acl, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
@@ -114,7 +114,7 @@ impl WithPacketType<Self> for Acl {
impl WithPacketType<Self> for Sco {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Sco, self)
prepend_packet_type(PacketType::Sco, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
@@ -124,7 +124,7 @@ impl WithPacketType<Self> for Sco {
impl WithPacketType<Self> for Event {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Event, self)
prepend_packet_type(PacketType::Event, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
@@ -132,9 +132,7 @@ impl WithPacketType<Self> for Event {
}
}
fn prepend_packet_type<T: Packet>(packet_type: PacketType, packet: T) -> Vec<u8> {
// TODO: refactor if `pdl` crate adds API for writing into buffer (github.com/google/pdl/issues/74)
let mut packet_bytes = packet.to_vec();
fn prepend_packet_type(packet_type: PacketType, mut packet_bytes: Vec<u8>) -> Vec<u8> {
packet_bytes.insert(0, packet_type.into());
packet_bytes
}
+7 -10
View File
@@ -22,8 +22,9 @@ use bytes::Bytes;
#[test]
fn prepends_packet_type() {
let packet_type = PacketType::Event;
let actual = prepend_packet_type(packet_type, FakePacket { bytes: vec![0xFF] });
assert_eq!(vec![0x04, 0xFF], actual);
let packet_bytes = vec![0x00, 0x00, 0x00, 0x00];
let actual = prepend_packet_type(packet_type, packet_bytes);
assert_eq!(vec![0x04, 0x00, 0x00, 0x00, 0x00], actual);
}
#[test]
@@ -74,15 +75,11 @@ fn test_packet_roundtrip_with_type() {
}
#[derive(Debug, PartialEq)]
struct FakePacket {
bytes: Vec<u8>,
}
struct FakePacket;
impl FakePacket {
fn parse(bytes: &[u8]) -> Result<Self, Error> {
Ok(Self {
bytes: bytes.to_vec(),
})
fn parse(_bytes: &[u8]) -> Result<Self, Error> {
Ok(Self)
}
}
@@ -92,6 +89,6 @@ impl Packet for FakePacket {
}
fn to_vec(self) -> Vec<u8> {
self.bytes
Vec::new()
}
}
@@ -14,17 +14,7 @@
//! Devices and connections to them
#[cfg(feature = "unstable_extended_adv")]
use crate::wrapper::{
hci::packets::{
self, AdvertisingEventProperties, AdvertisingFilterPolicy, Enable, EnabledSet,
FragmentPreference, LeSetAdvertisingSetRandomAddressBuilder,
LeSetExtendedAdvertisingDataBuilder, LeSetExtendedAdvertisingEnableBuilder,
LeSetExtendedAdvertisingParametersBuilder, Operation, OwnAddressType, PeerAddressType,
PrimaryPhyType, SecondaryPhyType,
},
ConversionError,
};
use crate::internal::hci::WithPacketType;
use crate::{
adv::AdvertisementDataBuilder,
wrapper::{
@@ -32,7 +22,7 @@ use crate::{
gatt_client::{ProfileServiceProxy, ServiceProxy},
hci::{
packets::{Command, ErrorCode, Event},
Address, HciCommand, WithPacketType,
Address, HciCommandWrapper,
},
host::Host,
l2cap::LeConnectionOrientedChannel,
@@ -49,9 +39,6 @@ use pyo3::{
use pyo3_asyncio::tokio::into_future;
use std::path;
#[cfg(test)]
mod tests;
/// Represents the various properties of some device
pub struct DeviceConfiguration(PyObject);
@@ -82,24 +69,11 @@ impl ToPyObject for DeviceConfiguration {
}
}
/// Used for tracking what advertising state a device might be in
#[derive(PartialEq)]
enum AdvertisingStatus {
AdvertisingLegacy,
AdvertisingExtended,
NotAdvertising,
}
/// A device that can send/receive HCI frames.
pub struct Device {
obj: PyObject,
advertising_status: AdvertisingStatus,
}
#[derive(Clone)]
pub struct Device(PyObject);
impl Device {
#[cfg(feature = "unstable_extended_adv")]
const ADVERTISING_HANDLE_EXTENDED: u8 = 0x00;
/// Creates a Device. When optional arguments are not specified, the Python object specifies the
/// defaults.
pub fn new(
@@ -120,10 +94,7 @@ impl Device {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call((), Some(kwargs))
.map(|any| Self {
obj: any.into(),
advertising_status: AdvertisingStatus::NotAdvertising,
})
.map(|any| Self(any.into()))
})
}
@@ -140,38 +111,28 @@ impl Device {
intern!(py, "from_config_file_with_hci"),
(device_config, source.0, sink.0),
)
.map(|any| Self {
obj: any.into(),
advertising_status: AdvertisingStatus::NotAdvertising,
})
.map(|any| Self(any.into()))
})
}
/// Create a Device configured to communicate with a controller through an HCI source/sink
pub fn with_hci(name: &str, address: Address, source: Source, sink: Sink) -> PyResult<Self> {
pub fn with_hci(name: &str, address: &str, source: Source, sink: Sink) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call_method1(intern!(py, "with_hci"), (name, address.0, source.0, sink.0))
.map(|any| Self {
obj: any.into(),
advertising_status: AdvertisingStatus::NotAdvertising,
})
.call_method1(intern!(py, "with_hci"), (name, address, source.0, sink.0))
.map(|any| Self(any.into()))
})
}
/// Sends an HCI command on this Device, returning the command's event result.
///
/// When `check_result` is `true`, then an `Err` will be returned if the controller's response
/// did not have an event code of "success".
pub async fn send_command(&self, command: Command, check_result: bool) -> PyResult<Event> {
let bumble_hci_command = HciCommand::try_from(command)?;
pub async fn send_command(&self, command: &Command, check_result: bool) -> PyResult<Event> {
Python::with_gil(|py| {
self.obj
self.0
.call_method1(
py,
intern!(py, "send_command"),
(bumble_hci_command, check_result),
(HciCommandWrapper(command.clone()), check_result),
)
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
@@ -190,7 +151,7 @@ impl Device {
/// Turn the device on
pub async fn power_on(&self) -> PyResult<()> {
Python::with_gil(|py| {
self.obj
self.0
.call_method0(py, intern!(py, "power_on"))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
@@ -201,7 +162,7 @@ impl Device {
/// Connect to a peer
pub async fn connect(&self, peer_addr: &str) -> PyResult<Connection> {
Python::with_gil(|py| {
self.obj
self.0
.call_method1(py, intern!(py, "connect"), (peer_addr,))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
@@ -219,7 +180,7 @@ impl Device {
});
Python::with_gil(|py| {
self.obj
self.0
.call_method1(py, intern!(py, "add_listener"), ("connection", boxed))
})
.map(|_| ())
@@ -230,7 +191,7 @@ impl Device {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("filter_duplicates", filter_duplicates)?;
self.obj
self.0
.call_method(py, intern!(py, "start_scanning"), (), Some(kwargs))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
@@ -248,7 +209,7 @@ impl Device {
});
Python::with_gil(|py| {
self.obj
self.0
.call_method1(py, intern!(py, "add_listener"), ("advertisement", boxed))
})
.map(|_| ())
@@ -257,7 +218,7 @@ impl Device {
/// Set the advertisement data to be used when [Device::start_advertising] is called.
pub fn set_advertising_data(&mut self, adv_data: AdvertisementDataBuilder) -> PyResult<()> {
Python::with_gil(|py| {
self.obj.setattr(
self.0.setattr(
py,
intern!(py, "advertising_data"),
adv_data.into_bytes().as_slice(),
@@ -269,162 +230,35 @@ impl Device {
/// Returns the host used by the device, if any
pub fn host(&mut self) -> PyResult<Option<Host>> {
Python::with_gil(|py| {
self.obj
self.0
.getattr(py, intern!(py, "host"))
.map(|obj| obj.into_option(Host::from))
})
}
/// Start advertising the data set with [Device.set_advertisement].
///
/// When `auto_restart` is set to `true`, then the device will automatically restart advertising
/// when a connected device is disconnected.
pub async fn start_advertising(&mut self, auto_restart: bool) -> PyResult<()> {
if self.advertising_status == AdvertisingStatus::AdvertisingExtended {
return Err(PyErr::new::<PyException, _>("Already advertising in extended mode. Stop the existing extended advertisement to start a legacy advertisement."));
}
// Bumble allows (and currently ignores) calling `start_advertising` when already
// advertising. Because that behavior may change in the future, we continue to delegate the
// handling to bumble.
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("auto_restart", auto_restart)?;
self.obj
self.0
.call_method(py, intern!(py, "start_advertising"), (), Some(kwargs))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())?;
self.advertising_status = AdvertisingStatus::AdvertisingLegacy;
Ok(())
}
/// Start advertising the data set in extended mode, replacing any existing extended adv. The
/// advertisement will be non-connectable.
///
/// Fails if the device is already advertising in legacy mode.
#[cfg(feature = "unstable_extended_adv")]
pub async fn start_advertising_extended(
&mut self,
adv_data: AdvertisementDataBuilder,
) -> PyResult<()> {
// TODO: add tests when local controller object supports extended advertisement commands (github.com/google/bumble/pull/238)
match self.advertising_status {
AdvertisingStatus::AdvertisingLegacy => return Err(PyErr::new::<PyException, _>("Already advertising in legacy mode. Stop the existing legacy advertisement to start an extended advertisement.")),
// Stop the current extended advertisement before advertising with new data.
// We could just issue an LeSetExtendedAdvertisingData command, but this approach
// allows better future flexibility if `start_advertising_extended` were to change.
AdvertisingStatus::AdvertisingExtended => self.stop_advertising_extended().await?,
_ => {}
}
// set extended params
let properties = AdvertisingEventProperties {
connectable: 0,
scannable: 0,
directed: 0,
high_duty_cycle: 0,
legacy: 0,
anonymous: 0,
tx_power: 0,
};
let extended_advertising_params_cmd = LeSetExtendedAdvertisingParametersBuilder {
advertising_event_properties: properties,
advertising_filter_policy: AdvertisingFilterPolicy::AllDevices,
advertising_handle: Self::ADVERTISING_HANDLE_EXTENDED,
advertising_sid: 0,
advertising_tx_power: 0,
own_address_type: OwnAddressType::RandomDeviceAddress,
peer_address: default_ignored_peer_address(),
peer_address_type: PeerAddressType::PublicDeviceOrIdentityAddress,
primary_advertising_channel_map: 7,
primary_advertising_interval_max: 200,
primary_advertising_interval_min: 100,
primary_advertising_phy: PrimaryPhyType::Le1m,
scan_request_notification_enable: Enable::Disabled,
secondary_advertising_max_skip: 0,
secondary_advertising_phy: SecondaryPhyType::Le1m,
};
self.send_command(extended_advertising_params_cmd.into(), true)
.await?;
// set random address
let random_address: packets::Address =
self.random_address()?.try_into().map_err(|e| match e {
ConversionError::Python(pyerr) => pyerr,
ConversionError::Native(e) => PyErr::new::<PyException, _>(format!("{e:?}")),
})?;
let random_address_cmd = LeSetAdvertisingSetRandomAddressBuilder {
advertising_handle: Self::ADVERTISING_HANDLE_EXTENDED,
random_address,
};
self.send_command(random_address_cmd.into(), true).await?;
// set adv data
let advertising_data_cmd = LeSetExtendedAdvertisingDataBuilder {
advertising_data: adv_data.into_bytes(),
advertising_handle: Self::ADVERTISING_HANDLE_EXTENDED,
fragment_preference: FragmentPreference::ControllerMayFragment,
operation: Operation::CompleteAdvertisement,
};
self.send_command(advertising_data_cmd.into(), true).await?;
// enable adv
let extended_advertising_enable_cmd = LeSetExtendedAdvertisingEnableBuilder {
enable: Enable::Enabled,
enabled_sets: vec![EnabledSet {
advertising_handle: Self::ADVERTISING_HANDLE_EXTENDED,
duration: 0,
max_extended_advertising_events: 0,
}],
};
self.send_command(extended_advertising_enable_cmd.into(), true)
.await?;
self.advertising_status = AdvertisingStatus::AdvertisingExtended;
Ok(())
.map(|_| ())
}
/// Stop advertising.
pub async fn stop_advertising(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.obj
self.0
.call_method0(py, intern!(py, "stop_advertising"))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())?;
if self.advertising_status == AdvertisingStatus::AdvertisingLegacy {
self.advertising_status = AdvertisingStatus::NotAdvertising;
}
Ok(())
}
/// Stop advertising extended.
#[cfg(feature = "unstable_extended_adv")]
pub async fn stop_advertising_extended(&mut self) -> PyResult<()> {
if AdvertisingStatus::AdvertisingExtended != self.advertising_status {
return Ok(());
}
// disable adv
let extended_advertising_enable_cmd = LeSetExtendedAdvertisingEnableBuilder {
enable: Enable::Disabled,
enabled_sets: vec![EnabledSet {
advertising_handle: Self::ADVERTISING_HANDLE_EXTENDED,
duration: 0,
max_extended_advertising_events: 0,
}],
};
self.send_command(extended_advertising_enable_cmd.into(), true)
.await?;
self.advertising_status = AdvertisingStatus::NotAdvertising;
Ok(())
.map(|_| ())
}
/// Registers an L2CAP connection oriented channel server. When a client connects to the server,
@@ -452,7 +286,7 @@ impl Device {
kwargs.set_opt_item("max_credits", max_credits)?;
kwargs.set_opt_item("mtu", mtu)?;
kwargs.set_opt_item("mps", mps)?;
self.obj.call_method(
self.0.call_method(
py,
intern!(py, "register_l2cap_channel_server"),
(),
@@ -461,15 +295,6 @@ impl Device {
})?;
Ok(())
}
/// Gets the Device's `random_address` property
pub fn random_address(&self) -> PyResult<Address> {
Python::with_gil(|py| {
self.obj
.getattr(py, intern!(py, "random_address"))
.map(Address)
})
}
}
/// A connection to a remote device.
@@ -626,13 +451,3 @@ impl Advertisement {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "data")).map(AdvertisingData))
}
}
/// Use this address when sending an HCI command that requires providing a peer address, but the
/// command is such that the peer address will be ignored.
///
/// Internal to bumble, this address might mean "any", but a packets::Address typically gets sent
/// directly to a controller, so we don't have to worry about it.
#[cfg(feature = "unstable_extended_adv")]
fn default_ignored_peer_address() -> packets::Address {
packets::Address::try_from(0x0000_0000_0000_u64).unwrap()
}
-23
View File
@@ -1,23 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "unstable_extended_adv")]
use crate::wrapper::device::default_ignored_peer_address;
#[test]
#[cfg(feature = "unstable_extended_adv")]
fn default_peer_address_does_not_panic() {
let result = std::panic::catch_unwind(default_ignored_peer_address);
assert!(result.is_ok())
}
+26 -59
View File
@@ -14,19 +14,18 @@
//! HCI
// re-export here, and internal usages of these imports should refer to this mod, not the internal
// mod
pub(crate) use crate::internal::hci::WithPacketType;
pub use crate::internal::hci::{packets, Error, Packet};
use crate::wrapper::{
hci::packets::{AddressType, Command, ErrorCode},
ConversionError,
use crate::{
internal::hci::WithPacketType,
wrapper::hci::packets::{AddressType, Command, ErrorCode},
};
use itertools::Itertools as _;
use pyo3::{
exceptions::PyException, intern, types::PyModule, FromPyObject, IntoPy, PyAny, PyErr, PyObject,
PyResult, Python, ToPyObject,
exceptions::PyException,
intern, pyclass, pymethods,
types::{PyBytes, PyModule},
FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
};
/// Provides helpers for interacting with HCI
@@ -44,45 +43,17 @@ impl HciConstant {
}
}
/// Bumble's representation of an HCI command.
pub(crate) struct HciCommand(pub(crate) PyObject);
impl HciCommand {
fn from_bytes(bytes: &[u8]) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.hci"))?
.getattr(intern!(py, "HCI_Command"))?
.call_method1(intern!(py, "from_bytes"), (bytes,))
.map(|obj| Self(obj.to_object(py)))
})
}
}
impl TryFrom<Command> for HciCommand {
type Error = PyErr;
fn try_from(value: Command) -> Result<Self, Self::Error> {
HciCommand::from_bytes(&value.to_vec_with_packet_type())
}
}
impl IntoPy<PyObject> for HciCommand {
fn into_py(self, _py: Python<'_>) -> PyObject {
self.0
}
}
/// A Bluetooth address
#[derive(Clone)]
pub struct Address(pub(crate) PyObject);
impl Address {
/// Creates a new [Address] object.
pub fn new(address: &str, address_type: AddressType) -> PyResult<Self> {
/// Creates a new [Address] object
pub fn new(address: &str, address_type: &AddressType) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Address"))?
.call1((address, address_type))
.call1((address, address_type.to_object(py)))
.map(|any| Self(any.into()))
})
}
@@ -147,31 +118,27 @@ impl ToPyObject for Address {
}
}
/// An error meaning that the u64 value did not represent a valid BT address.
#[derive(Debug)]
pub struct InvalidAddress(u64);
/// Implements minimum necessary interface to be treated as bumble's [HCI_Command].
/// While pyo3's macros do not support generics, this could probably be refactored to allow multiple
/// implementations of the HCI_Command methods in the future, if needed.
#[pyclass]
pub(crate) struct HciCommandWrapper(pub(crate) Command);
impl TryInto<packets::Address> for Address {
type Error = ConversionError<InvalidAddress>;
#[pymethods]
impl HciCommandWrapper {
fn __bytes__(&self, py: Python) -> PyResult<PyObject> {
let bytes = PyBytes::new(py, &self.0.clone().to_vec_with_packet_type());
Ok(bytes.into_py(py))
}
fn try_into(self) -> Result<packets::Address, Self::Error> {
let addr_le_bytes = self.as_le_bytes().map_err(ConversionError::Python)?;
// packets::Address only supports converting from a u64 (TODO: update if/when it supports converting from [u8; 6] -- https://github.com/google/pdl/issues/75)
// So first we take the python `Address` little-endian bytes (6 bytes), copy them into a
// [u8; 8] in little-endian format, and finally convert it into a u64.
let mut buf = [0_u8; 8];
buf[0..6].copy_from_slice(&addr_le_bytes);
let address_u64 = u64::from_le_bytes(buf);
packets::Address::try_from(address_u64)
.map_err(InvalidAddress)
.map_err(ConversionError::Native)
#[getter]
fn op_code(&self) -> u16 {
self.0.get_op_code().into()
}
}
impl IntoPy<PyObject> for AddressType {
fn into_py(self, py: Python<'_>) -> PyObject {
impl ToPyObject for AddressType {
fn to_object(&self, py: Python<'_>) -> PyObject {
u8::from(self).to_object(py)
}
}
-9
View File
@@ -132,12 +132,3 @@ pub(crate) fn wrap_python_async<'a>(py: Python<'a>, function: &'a PyAny) -> PyRe
.getattr(intern!(py, "wrap_async"))?
.call1((function,))
}
/// Represents the two major kinds of errors that can occur when converting between Rust and Python.
pub enum ConversionError<T> {
/// Occurs across the Python/native boundary.
Python(PyErr),
/// Occurs within the native ecosystem, such as when performing more transformations before
/// finally converting to the native type.
Native(T),
}
+3 -4
View File
@@ -15,7 +15,6 @@
//! HCI packet transport
use crate::wrapper::controller::Controller;
use futures::executor::block_on;
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python};
/// A source/sink pair for HCI packet I/O.
@@ -59,9 +58,9 @@ impl Transport {
impl Drop for Transport {
fn drop(&mut self) {
// don't spawn a thread to handle closing, as it may get dropped at program termination,
// resulting in `RuntimeWarning: coroutine ... was never awaited` from Python
let _ = block_on(self.close());
// can't await in a Drop impl, but we can at least spawn a task to do it
let obj = self.0.clone();
tokio::spawn(async move { Self(obj).close().await });
}
}
+1 -1
View File
@@ -33,7 +33,7 @@ include_package_data = True
install_requires =
aiohttp ~= 3.8; platform_system!='Emscripten'
appdirs >= 1.4; platform_system!='Emscripten'
bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
bt-test-interfaces >= 0.0.5; platform_system!='Emscripten'
click == 8.1.3; platform_system!='Emscripten'
cryptography == 39; platform_system!='Emscripten'
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
-75
View File
@@ -34,8 +34,6 @@ from bumble.pairing import PairingConfig, PairingDelegate
from bumble.smp import (
SMP_PAIRING_NOT_SUPPORTED_ERROR,
SMP_CONFIRM_VALUE_FAILED_ERROR,
OobContext,
OobLegacyContext,
)
from bumble.core import ProtocolError
from bumble.hci import HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
@@ -577,77 +575,6 @@ async def test_self_smp_public_address():
await _test_self_smp_with_configs(pairing_config, pairing_config)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_oob_sc():
oob_context_1 = OobContext()
oob_context_2 = OobContext()
pairing_config_1 = PairingConfig(
mitm=True,
sc=True,
bonding=True,
oob=PairingConfig.OobConfig(oob_context_1, oob_context_2.share(), None),
)
pairing_config_2 = PairingConfig(
mitm=True,
sc=True,
bonding=True,
oob=PairingConfig.OobConfig(oob_context_2, oob_context_1.share(), None),
)
await _test_self_smp_with_configs(pairing_config_1, pairing_config_2)
pairing_config_3 = PairingConfig(
mitm=True,
sc=True,
bonding=True,
oob=PairingConfig.OobConfig(oob_context_2, None, None),
)
await _test_self_smp_with_configs(pairing_config_1, pairing_config_3)
await _test_self_smp_with_configs(pairing_config_3, pairing_config_1)
pairing_config_4 = PairingConfig(
mitm=True,
sc=True,
bonding=True,
oob=PairingConfig.OobConfig(oob_context_2, oob_context_2.share(), None),
)
with pytest.raises(ProtocolError) as error:
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
with pytest.raises(ProtocolError):
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_oob_legacy():
legacy_context = OobLegacyContext()
pairing_config_1 = PairingConfig(
mitm=True,
sc=False,
bonding=True,
oob=PairingConfig.OobConfig(None, None, legacy_context),
)
pairing_config_2 = PairingConfig(
mitm=True,
sc=True,
bonding=True,
oob=PairingConfig.OobConfig(OobContext(), None, legacy_context),
)
await _test_self_smp_with_configs(pairing_config_1, pairing_config_2)
await _test_self_smp_with_configs(pairing_config_2, pairing_config_1)
# -----------------------------------------------------------------------------
async def run_test_self():
await test_self_connection()
@@ -658,8 +585,6 @@ async def run_test_self():
await test_self_smp_wrong_pin()
await test_self_smp_over_classic()
await test_self_smp_public_address()
await test_self_smp_oob_sc()
await test_self_smp_oob_legacy()
# -----------------------------------------------------------------------------
-23
View File
@@ -17,16 +17,11 @@
# -----------------------------------------------------------------------------
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address
from bumble.core import AdvertisingData
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def reversed_hex(hex_str):
return bytes(reversed(bytes.fromhex(hex_str)))
@@ -238,23 +233,6 @@ def test_ah():
assert value == expected
# -----------------------------------------------------------------------------
def test_oob_data():
oob_data = OobData(
address=Address("F0:F1:F2:F3:F4:F5"),
role=LeRole.BOTH_PERIPHERAL_PREFERRED,
shared_data=OobSharedData(c=bytes([1, 2]), r=bytes([3, 4])),
)
oob_data_ad = oob_data.to_ad()
oob_data_bytes = bytes(oob_data_ad)
oob_data_ad_parsed = AdvertisingData.from_bytes(oob_data_bytes)
oob_data_parsed = OobData.from_ad(oob_data_ad_parsed)
assert oob_data_parsed.address == oob_data.address
assert oob_data_parsed.role == oob_data.role
assert oob_data_parsed.shared_data.c == oob_data.shared_data.c
assert oob_data_parsed.shared_data.r == oob_data.shared_data.r
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ecc()
@@ -268,4 +246,3 @@ if __name__ == '__main__':
test_h6()
test_h7()
test_ah()
test_oob_data()