Compare commits

...

7 Commits

Author SHA1 Message Date
zxzxwu
e4364d18a7 Merge pull request #418 from zxzxwu/rfc
RFCOMM: Slightly refactor and correct constants
2024-02-01 01:30:53 +08:00
Josh Wu
6a34c9f224 RFCOMM: Slightly refactor and correct constants 2024-02-01 01:18:56 +08:00
zxzxwu
071fc2723a Merge pull request #376 from zxzxwu/host
Manage lifecycle of CIS and SCO links in host
2024-01-28 22:09:08 +08:00
zxzxwu
ef4ea86f58 Merge pull request #381 from zxzxwu/offload
Support non-directed address generation offload
2024-01-28 22:08:32 +08:00
Gilles Boccon-Gibod
dfdaa149d0 Merge pull request #337 from google/gbg/avrcp
Add AVRCP support
2024-01-28 01:27:52 -08:00
Josh Wu
c40824e51c Support non-directed address generation offload 2024-01-26 16:02:40 +08:00
Josh Wu
da60386385 Manage lifecycle of CIS and SCO links in host 2024-01-18 11:56:38 +08:00
8 changed files with 171 additions and 150 deletions

View File

@@ -21,6 +21,7 @@ import functools
import json
import asyncio
import logging
import secrets
from contextlib import asynccontextmanager, AsyncExitStack, closing
from dataclasses import dataclass
from collections.abc import Iterable
@@ -994,12 +995,15 @@ class DeviceConfiguration:
irk = config.get('irk')
if irk:
self.irk = bytes.fromhex(irk)
else:
elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
# Construct an IRK from the address bytes
# NOTE: this is not secure, but will always give the same IRK for the same
# address
address_bytes = bytes(self.address)
self.irk = (address_bytes * 3)[:16]
else:
# Fallback - when both IRK and address are not set, randomly generate an IRK.
self.irk = secrets.token_bytes(16)
# Load advertising data
advertising_data = config.get('advertising_data')
@@ -1581,6 +1585,16 @@ class Device(CompositeEventEmitter):
if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
# Add an empty entry for non-directed address generation.
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=Address.ANY.address_type,
peer_identity_address=Address.ANY,
peer_irk=bytes(16),
local_irk=self.irk,
)
)
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
@@ -1681,8 +1695,8 @@ class Device(CompositeEventEmitter):
peer_address = target
peer_address_type = target.address_type
else:
peer_address = Address('00:00:00:00:00:00')
peer_address_type = Address.PUBLIC_DEVICE_ADDRESS
peer_address = Address.ANY
peer_address_type = Address.ANY.address_type
# Set the advertising parameters
await self.send_command(
@@ -3066,34 +3080,30 @@ class Device(CompositeEventEmitter):
cig_id=cig_id,
)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()
with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}
@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
return await asyncio.gather(*pending_cis_establishments.values())
# [LE only]
@@ -3741,7 +3751,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)
# [LE only]
@@ -3821,7 +3831,7 @@ class Device(CompositeEventEmitter):
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)
@@ -3829,7 +3839,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)
@host_event_handler

View File

@@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct
@@ -161,9 +162,25 @@ class Connection:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
@@ -183,6 +200,8 @@ class Host(AbortableEventEmitter):
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.local_version = None
@@ -696,25 +715,36 @@ class Host(AbortableEventEmitter):
def on_hci_disconnection_complete_event(self, event):
# Find the connection
if (connection := self.connections.get(event.connection_handle)) is None:
handle = event.connection_handle
if (
connection := (
self.connections.get(handle)
or self.cis_links.get(handle)
or self.sco_links.get(handle)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return
if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
self.emit('disconnection', handle, event.reason)
(
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0)
)
else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('disconnection_failure', event.connection_handle, event.status)
self.emit('disconnection_failure', handle, event.status)
def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -775,6 +805,10 @@ class Host(AbortableEventEmitter):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
@@ -841,6 +875,11 @@ class Host(AbortableEventEmitter):
f'{event.bd_addr}'
)
self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)
# Notify the client
self.emit(
'sco_connection',

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
import asyncio
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
@@ -60,27 +61,18 @@ logger = logging.getLogger(__name__)
RFCOMM_PSM = 0x0003
class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
# Frame types
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum):
PN = 0x20
MSC = 0x38
RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}
# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38
# FCS CRC
CRC_TABLE = bytes([
@@ -118,7 +110,7 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_WINDOW_SIZE = 16
RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
@@ -183,7 +175,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame:
def __init__(
self,
frame_type: int,
frame_type: FrameType,
c_r: int,
dlci: int,
p_f: int,
@@ -206,14 +198,11 @@ class RFCOMM_Frame:
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
if frame_type == FrameType.UIH:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2
@@ -237,24 +226,24 @@ class RFCOMM_Frame:
@staticmethod
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)
@staticmethod
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)
@staticmethod
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)
@staticmethod
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)
@staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
@@ -262,7 +251,7 @@ class RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = data[1] & 0xEF
frame_type = FrameType(data[1] & 0xEF)
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -291,7 +280,7 @@ class RFCOMM_Frame:
def __str__(self) -> str:
return (
f'{color(self.type_name(), "yellow")}'
f'{color(self.type.name, "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
@@ -301,6 +290,7 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
dlci: int
cl: int
@@ -310,23 +300,11 @@ class RFCOMM_MCC_PN:
max_retransmissions: int
window_size: int
def __init__(
self,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
def __post_init__(self) -> None:
if self.window_size < 1 or self.window_size > 7:
logger.warning(
f'Error Recovery Window size {self.window_size} is out of range [1, 7].'
)
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
@@ -337,7 +315,7 @@ class RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7],
window_size=data[7] & 0x07,
)
def __bytes__(self) -> bytes:
@@ -350,23 +328,14 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF,
# Only 3 bits are meaningful.
self.window_size & 0x07,
]
)
def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
dlci: int
fc: int
@@ -375,16 +344,6 @@ class RFCOMM_MCC_MSC:
ic: int
dv: int
def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC(
@@ -409,16 +368,6 @@ class RFCOMM_MCC_MSC:
]
)
def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
@@ -471,7 +420,7 @@ class DLC(EventEmitter):
self.multiplexer.send_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -485,9 +434,7 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -503,9 +450,7 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -559,9 +504,7 @@ class DLC(EventEmitter):
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
@@ -589,7 +532,7 @@ class DLC(EventEmitter):
max_retransmissions=0,
window_size=self.window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING)
@@ -711,7 +654,7 @@ class Multiplexer(EventEmitter):
if frame.dlci == 0:
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
if frame.type == FrameType.DM:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
@@ -724,7 +667,7 @@ class Multiplexer(EventEmitter):
dlc.on_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -772,10 +715,10 @@ class Multiplexer(EventEmitter):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == RFCOMM_MCC_PN_TYPE:
if mcc_type == MccType.PN:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
elif mcc_type == MccType.MSC:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -871,7 +814,7 @@ class Multiplexer(EventEmitter):
max_retransmissions=0,
window_size=window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING)

View File

@@ -1993,10 +1993,8 @@ class Manager(EventEmitter):
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
self.device.abort_on(
'flush', self.device.update_keys(str(identity_address), keys)
)
# Make sure on_pairing emits after key update.
await self.device.update_keys(str(identity_address), keys)
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

View File

@@ -99,7 +99,7 @@ development =
types-protobuf >= 4.21.0
avatar =
pandora-avatar == 0.0.5
rootcanal == 1.4.0 ; python_version>='3.10'
rootcanal == 1.6.0 ; python_version>='3.10'
documentation =
mkdocs >= 1.4.0
mkdocs-material >= 8.5.6

View File

@@ -467,9 +467,8 @@ async def test_cis():
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2
# TODO: Fix Host CIS support.
# await cis_links[0].disconnect()
# await cis_links[1].disconnect()
await cis_links[0].disconnect()
await cis_links[1].disconnect()
# -----------------------------------------------------------------------------

View File

@@ -24,11 +24,11 @@ from typing import Tuple
from .test_utils import TwoDevices
from bumble import core
from bumble import device
from bumble import hfp
from bumble import rfcomm
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -109,7 +109,7 @@ async def test_sco_setup():
devices[1].accept(devices[0].public_address),
)
def on_sco_request(_connection: device.Connection, _link_type: int):
def on_sco_request(_connection, _link_type: int):
connections[1].abort_on(
'disconnection',
devices[1].send_command(
@@ -124,17 +124,13 @@ async def test_sco_setup():
devices[1].on('sco_request', on_sco_request)
sco_connections = [
sco_connection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
devices[0].on(
'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link)
)
devices[1].on(
'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link)
)
for device, future in zip(devices, sco_connection_futures):
device.on('sco_connection', future.set_result)
await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
@@ -142,8 +138,17 @@ async def test_sco_setup():
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
)
)
sco_connections = await asyncio.gather(*sco_connection_futures)
await asyncio.gather(*sco_connections)
sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)
# -----------------------------------------------------------------------------

View File

@@ -15,7 +15,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.rfcomm import RFCOMM_Frame
import asyncio
import pytest
from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
# -----------------------------------------------------------------------------
@@ -43,6 +47,29 @@ def test_frames():
basic_frame_check(frame)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
devices = test_utils.TwoDevices()
await devices.setup_connection()
accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)
multiplexer = await Client(devices.connections[1]).start()
dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel))
queues = [asyncio.Queue(), asyncio.Queue()]
for dlc, queue in zip(dlcs, queues):
dlc.sink = queue.put_nowait
dlcs[0].write(b'The quick brown fox jumps over the lazy dog')
assert await queues[1].get() == b'The quick brown fox jumps over the lazy dog'
dlcs[1].write(b'Lorem ipsum dolor sit amet')
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()