Merge pull request #483 from zxzxwu/rfc

RFCOMM: Handle packets received before DLC sink set
This commit is contained in:
zxzxwu
2024-05-10 16:34:57 +08:00
committed by GitHub
3 changed files with 62 additions and 5 deletions

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import collections
import dataclasses import dataclasses
import enum import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
@@ -54,6 +55,7 @@ logger = logging.getLogger(__name__)
# fmt: off # fmt: off
RFCOMM_PSM = 0x0003 RFCOMM_PSM = 0x0003
DEFAULT_RX_QUEUE_SIZE = 32
class FrameType(enum.IntEnum): class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
@@ -445,7 +447,8 @@ class DLC(EventEmitter):
RESET = 0x05 RESET = 0x05
connection_result: Optional[asyncio.Future] connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]] _sink: Optional[Callable[[bytes], None]]
_enqueued_rx_packets: collections.deque[bytes]
def __init__( def __init__(
self, self,
@@ -466,10 +469,12 @@ class DLC(EventEmitter):
self.state = DLC.State.INIT self.state = DLC.State.INIT
self.role = multiplexer.role self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0 self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None self.connection_result = None
self.drained = asyncio.Event() self.drained = asyncio.Event()
self.drained.set() self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets = collections.deque(maxlen=DEFAULT_RX_QUEUE_SIZE)
self._sink = None
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -477,6 +482,19 @@ class DLC(EventEmitter):
max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
) )
@property
def sink(self) -> Optional[Callable[[bytes], None]]:
return self._sink
@sink.setter
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
for packet in self._enqueued_rx_packets:
sink(packet) # pylint: disable=not-callable
self._enqueued_rx_packets.clear()
def change_state(self, new_state: State) -> None: def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}') logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
self.state = new_state self.state = new_state
@@ -549,8 +567,15 @@ class DLC(EventEmitter):
f'rx_credits={self.rx_credits}: {data.hex()}' f'rx_credits={self.rx_credits}: {data.hex()}'
) )
if data: if data:
if self.sink: if self._sink:
self.sink(data) # pylint: disable=not-callable self._sink(data) # pylint: disable=not-callable
else:
self._enqueued_rx_packets.append(data)
if (
self._enqueued_rx_packets.maxlen
and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
):
logger.warning(f'DLC [{self.dlci}] received packet queue is full')
# Update the credits # Update the credits
if self.rx_credits > 0: if self.rx_credits > 0:

View File

@@ -32,6 +32,8 @@ from bumble.rfcomm import (
RFCOMM_PSM, RFCOMM_PSM,
) )
_TIMEOUT = 0.1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def basic_frame_check(x): def basic_frame_check(x):
@@ -82,6 +84,29 @@ async def test_basic_connection() -> None:
assert await queues[0].get() == b'Lorem ipsum dolor sit amet' assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_receive_pdu_before_open_dlc_returns() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
DATA = b'123'
accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)
assert devices.connections[1]
multiplexer = await Client(devices.connections[1]).start()
open_dlc_task = asyncio.create_task(multiplexer.open_dlc(channel))
dlc_responder = await accept_future
dlc_responder.write(DATA)
dlc_initiator = await open_dlc_task
dlc_initiator_queue = asyncio.Queue() # type: ignore[var-annotated]
dlc_initiator.sink = dlc_initiator_queue.put_nowait
assert await asyncio.wait_for(dlc_initiator_queue.get(), timeout=_TIMEOUT) == DATA
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_record(): async def test_service_record():

View File

@@ -16,7 +16,8 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional, Type
from typing_extensions import Self
from bumble.controller import Controller from bumble.controller import Controller
from bumble.link import LocalLink from bumble.link import LocalLink
@@ -81,6 +82,12 @@ class TwoDevices:
def __getitem__(self, index: int) -> Device: def __getitem__(self, index: int) -> Device:
return self.devices[index] return self.devices[index]
@classmethod
async def create_with_connection(cls: Type[Self]) -> Self:
devices = cls()
await devices.setup_connection()
return devices
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_barrier(): async def async_barrier():