Compare commits

...

30 Commits

Author SHA1 Message Date
Josh Wu 72d821b1f6 Merge pull request #928 from zxzxwu/avdtp
AVDTP: Avoid explicit in_use management
2026-05-26 16:33:08 +08:00
Josh Wu afe064b4ea AVDTP: Make local stream endpoint in_use dyanmic property 2026-05-22 15:58:11 +08:00
Josh Wu 8d0cef70c2 AVDTP: Add keyword argument to long __init__ 2026-05-20 16:19:06 +08:00
Josh Wu 9cefde1c3e Merge pull request #923 from laozhuxinlu/avdtp_abort_issue_fix
Fix AVDTP endpoint resource leak by clearing the in_use flag on strea…
2026-05-20 16:18:35 +08:00
Clay Zhu ffb9d5f117 Fix AVDTP endpoint resource leak by clearing the in_use flag on stream close and abort commands. 2026-05-11 18:44:56 +08:00
Josh Wu 7d3be8157a Merge pull request #922 from zxzxwu/typing
Type some optional attributes
2026-05-09 16:15:42 +08:00
Josh Wu 9dc9c348e5 Merge pull request #920 from zxzxwu/avdtp
AVDTP: Make all handlers async
2026-05-07 17:39:24 +08:00
Josh Wu b18555539e Type some optional attributes 2026-05-06 17:16:40 +08:00
Josh Wu 8a853d5b2f AVDTP: Make all handlers async 2026-05-05 01:44:10 +08:00
Josh Wu 8988a85245 Merge pull request #919 from zxzxwu/sdp
SDP: Move parser functions to parser class
2026-04-29 13:21:13 +08:00
Josh Wu 0813da2278 SDP: Move parser functions to parser class 2026-04-28 13:27:50 +08:00
Gilles Boccon-Gibod a1ff183d44 Merge pull request #915 from dlech/notify-subscribers-type-hints
improve type hints for notify/indicate subscriber(s) methods
2026-04-27 21:45:38 +02:00
Gilles Boccon-Gibod 7adf44eddf Merge pull request #916 from dlech/fix-crash-in-attribute-repr
fix crash in `bumble.att.Attribute.__repr__`
2026-04-27 21:41:41 +02:00
Josh Wu 05accbf805 Merge pull request #918 from ibondarenko1/fix/avdtp-empty-pdu-guard
avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
2026-04-27 10:01:51 +08:00
Josh Wu 80f54f2a09 Merge pull request #917 from dlech/fix-regex-with-backslash
Fix regex syntax warning in sdp_test.py.
2026-04-27 09:55:36 +08:00
ibondarenko1 07b5e33e09 avdtp: address review nits — use truthy checks
Per @zxzxwu review on #918:
- bumble/avdtp.py: replace `if len(pdu) < 1:` with `if not pdu:`
- tests/avdtp_test.py: replace `assert completed == []` with
  `assert not completed`

Both are idiomatic Python truthy checks; behavior identical.
2026-04-26 18:49:55 -07:00
ibondarenko1 b874e26a4f avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
A remote peer can send an AVDTP frame shorter than the assembler expects.
The current MessageAssembler.on_pdu() unconditionally accesses pdu[0],
pdu[1], and (for START packets) pdu[2], so a 0-, 1-, or 2-byte frame
raises IndexError. The exception propagates up through L2CAP's read loop
and tears down the channel — same DoS class as #912 (empty ATT PDU) and
#914 (unbounded SDP recursion).

Fix: validate length before each access. Empty PDUs and packets shorter
than the type-specific minimum are logged and dropped; the assembler
stays alive so the L2CAP channel is not torn down.

- bumble/avdtp.py: length guards in MessageAssembler.on_pdu before
  accessing pdu[0], pdu[1], pdu[2].
- tests/avdtp_test.py: regression test covering empty PDU, 1-byte SINGLE,
  1-byte START, 2-byte START — all four would have raised IndexError
  pre-fix; assembler now drops without raising.
2026-04-26 18:16:15 -07:00
David Lechner baa5257780 improve type hints for notify/indicate subscriber(s) methods
Pyright expects generic type parameters to be specified for the
Attribute class, otherwise it treats the type as Unknown which can
trigger reportUnknownMemberType errors.

This can be solved by using a generic type parameter for these methods
which also has the benefit of making sure that the value parameter has
the correct type for the attribute.

In some cases, a new local `value_as_bytes` variable is needed to avoid
type errors and makes the code less confusing by not overwriting the
original `value` variable.
2026-04-26 09:43:40 -05:00
David Lechner a91ea9110c Fix regex syntax warning in sdp_test.py.
Change regex match string to raw string to avoid syntax warning:

    tests/sdp_test.py:218: SyntaxWarning: invalid escape sequence '\d'
    assert not re.search("Expect \d+ bytes, got \d+", caplog.text)

In the future, this will become an error, so we should fix it now.
2026-04-26 09:31:18 -05:00
Josh Wu 1686c5b11b Merge pull request #914 from ibondarenko1/fix/sdp-recursion-depth-limit
sdp: bound DataElement parse recursion to prevent RecursionError DoS
2026-04-26 17:22:59 +08:00
David Lechner d9481992bb fix crash in bumble.att.Attribute.__repr__
If an attribute does not contains a bytes value, it would crash with
something like:

    AttributeError: 'NoneType' object has no attribute 'hex'

Clearly, the intention here was to use `value_str` to avoid this
possibility.
2026-04-25 17:01:25 -05:00
ibondarenko1 16d0ed56cf sdp: address review nits (import at top, InvalidPacketError)
- bumble/sdp.py: replace raise ValueError with raise InvalidPacketError
  in DataElement.list_from_bytes depth guard. InvalidPacketError
  already imported at line 34 and extends ValueError so the existing
  regression test continues to match.
- tests/sdp_test.py: remove duplicate 'import pytest' inside
  test_nested_sequence_recursion_guard; pytest already imported at
  module top (line 23).

Threading.local counter left as-is per zxzxwu's 'leave it here and
refactor later' comment on the PR.
2026-04-24 11:42:49 -07:00
Ievgen Bondarenko c55eb156b8 sdp: fix lint formatting (black: blank line after import pytest) 2026-04-24 00:06:56 -07:00
ibondarenko1 8614881fb3 sdp: bound DataElement parse recursion to prevent RecursionError DoS
DataElement.from_bytes -> list_from_bytes -> (SEQUENCE/ALTERNATIVE
constructor dispatches back to list_from_bytes) had no depth limit. A
malicious SDP peer could send a PDU of a few kilobytes containing ~1000
nested SEQUENCE tags and exhaust the Python recursion stack, crashing the
host with an unhandled RecursionError propagating out of the SDP handler.

Reachable via: any remote Bluetooth device that Bumble performs SDP
service discovery against (default during Classic connection setup).

Same family as PR #912 (ATT_PDU.from_bytes empty PDU IndexError) - remote
unchecked-input parser crash in the Bluetooth stack.

Fix: thread-local depth counter, cap nesting at 32 (well above anything a
legitimate service record uses). Added two regression tests covering the
deep-nesting reject path and normal 16-level-nested SEQUENCE parsing.

Reproducer (4.5 KB payload, deterministic crash on 0.0.228):

    from bumble.sdp import DataElement
    inner = b"\x35\x00"
    for _ in range(1500):
        size = len(inner)
        if size < 65535:
            inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
    DataElement.from_bytes(inner)  # RecursionError before fix

Signed-off-by: ibondarenko1 <ibondarenko1@users.noreply.github.com>
2026-04-23 00:53:06 -07:00
Josh Wu 27d02ef18d Merge pull request #913 from zxzxwu/sdp
SDP: Fix wrong parameter size
2026-04-20 16:32:37 +08:00
Josh Wu c0725e2a4a SDP: Fix wrong parameter size 2026-04-20 16:23:19 +08:00
Josh Wu bf0784dde4 Merge pull request #912 from ibondarenko1/fix/empty-pdu-crash
fix: add input validation to prevent remote crash from empty/malforme…
2026-04-20 14:36:48 +08:00
Ievgen Bondarenko 444f43f6a3 fix: address review feedback - use InvalidPacketError and abort on buffer overflow
- att.py: raise core.InvalidPacketError instead of generic ValueError
- smp.py: raise core.InvalidPacketError instead of generic ValueError
- hfp.py: add MAX_BUFFER_SIZE class constant (64KB)
- hfp.py: drop incoming data when it would overflow buffer instead of
  truncating, preserving existing partial-packet state

Per review comments on PR #912 by @zxzxwu.
2026-04-16 11:24:09 -07:00
Gilles Boccon-Gibod 2420c47cf1 Merge pull request #911 from google/gbg/issue-910
release command semaphore after timeout
2026-04-16 18:11:57 +02:00
Ievgen Bondarenko 0a78e7506b fix: add input validation to prevent remote crash from empty/malformed PDUs
Add length checks in from_bytes() for ATT and SMP protocol parsers
to prevent IndexError crashes from empty PDUs sent by remote Bluetooth
devices. Also add buffer size limit and UTF-8 error handling in HFP
protocol to prevent memory exhaustion and decode crashes.

- bumble/att.py: validate PDU is non-empty before accessing pdu[0]
- bumble/smp.py: validate PDU is non-empty before accessing pdu[0]
- bumble/hfp.py: limit buffer to 64KB, handle invalid UTF-8 gracefully

These issues can be triggered by a remote Bluetooth device sending
malformed packets, causing denial of service on the host.
2026-04-16 01:43:41 -07:00
9 changed files with 593 additions and 377 deletions
+4 -2
View File
@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
from bumble.hci import HCI_Object
# -----------------------------------------------------------------------------
@@ -249,6 +249,8 @@ class ATT_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
if not pdu:
raise InvalidPacketError("Empty ATT PDU")
op_code = pdu[0]
subclass = ATT_PDU.pdu_classes.get(op_code)
@@ -1081,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}'
value_string = f', value={value_str}'
else:
value_string = ''
return (
+140 -77
View File
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
import asyncio
import enum
import logging
@@ -311,6 +312,13 @@ class MessageAssembler:
def on_pdu(self, pdu: bytes) -> None:
self.packet_count += 1
# Drop empty PDUs sent by remote — accessing pdu[0] below would
# raise IndexError, propagating up to the L2CAP read loop and
# tearing down the channel. Same class as #912 (ATT empty PDU).
if not pdu:
logger.warning('AVDTP message assembler: empty PDU dropped')
return
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
message_type = Message.MessageType(pdu[0] & 3)
@@ -324,6 +332,23 @@ class MessageAssembler:
Protocol.PacketType.SINGLE_PACKET,
Protocol.PacketType.START_PACKET,
):
# Both single and start packets carry the signal identifier in
# pdu[1]; start packets additionally carry the packet count in
# pdu[2]. Guard each access so a malformed remote frame can't
# crash the message assembler.
if len(pdu) < 2:
logger.warning(
'AVDTP %s packet too short (%d bytes); dropped',
packet_type.name,
len(pdu),
)
return
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
logger.warning(
'AVDTP START packet missing signal-packet count; dropped'
)
return
if self.message is not None:
# The previous message has not been terminated
logger.warning(
@@ -1453,8 +1478,23 @@ class Protocol(utils.EventEmitter):
handler = getattr(self, handler_name, None)
if handler:
try:
response = handler(message)
self.send_message(transaction_label, response)
result = handler(message)
if asyncio.iscoroutine(result):
async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
else:
@@ -1535,7 +1575,7 @@ class Protocol(utils.EventEmitter):
async def send_command(self, command: Message):
# TODO: support timeouts
# Send the command
(transaction_label, transaction_result) = await self.start_transaction()
transaction_label, transaction_result = await self.start_transaction()
self.send_message(transaction_label, command)
# Wait for the response
@@ -1600,14 +1640,14 @@ class Protocol(utils.EventEmitter):
async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid))
def on_discover_command(self, command: Discover_Command) -> Message | None:
async def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
]
return Discover_Response(endpoint_infos)
def on_get_capabilities_command(
async def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1616,7 +1656,7 @@ class Protocol(utils.EventEmitter):
return Get_Capabilities_Response(endpoint.capabilities)
def on_get_all_capabilities_command(
async def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1625,7 +1665,7 @@ class Protocol(utils.EventEmitter):
return Get_All_Capabilities_Response(endpoint.capabilities)
def on_set_configuration_command(
async def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1640,10 +1680,10 @@ class Protocol(utils.EventEmitter):
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream
result = stream.on_set_configuration_command(command.capabilities)
result = await stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response()
def on_get_configuration_command(
async def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1652,29 +1692,31 @@ class Protocol(utils.EventEmitter):
if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return endpoint.stream.on_get_configuration_command()
return await endpoint.stream.on_get_configuration_command()
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_reconfigure_command(command.capabilities)
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response()
def on_open_command(self, command: Open_Command) -> Message | None:
async def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_open_command()
result = await endpoint.stream.on_open_command()
return result or Open_Response()
def on_start_command(self, command: Start_Command) -> Message | None:
async def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1688,12 +1730,12 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_start_command()) is not None:
if (result := await endpoint.stream.on_start_command()) is not None:
return result
return Start_Response()
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1707,45 +1749,47 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_suspend_command()) is not None:
if (result := await endpoint.stream.on_suspend_command()) is not None:
return result
return Suspend_Response()
def on_close_command(self, command: Close_Command) -> Message | None:
async def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_close_command()
result = await endpoint.stream.on_close_command()
return result or Close_Response()
def on_abort_command(self, command: Abort_Command) -> Message | None:
async def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None:
return Abort_Response()
endpoint.stream.on_abort_command()
await endpoint.stream.on_abort_command()
return Abort_Response()
def on_security_control_command(
async def on_security_control_command(
self, command: Security_Control_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_security_control_command(command.data)
result = await endpoint.on_security_control_command(command.data)
return result or Security_Control_Response()
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_delayreport_command(command.delay)
result = await endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response()
@@ -1903,25 +1947,22 @@ class Stream:
await self.rtp_channel.disconnect()
self.rtp_channel = None
# Release the endpoint
self.local_endpoint.in_use = 0
self.change_state(State.IDLE)
def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_set_configuration_command(configuration)
result = await self.local_endpoint.on_set_configuration_command(configuration)
if result is not None:
return result
self.change_state(State.CONFIGURED)
return None
def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
if self.state not in (
State.CONFIGURED,
State.OPEN,
@@ -1929,25 +1970,25 @@ class Stream:
):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command()
return await self.local_endpoint.on_get_configuration_command()
def on_reconfigure_command(
async def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_reconfigure_command(configuration)
result = await self.local_endpoint.on_reconfigure_command(configuration)
if result is not None:
return result
return None
def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_open_command()
result = await self.local_endpoint.on_open_command()
if result is not None:
return result
@@ -1957,7 +1998,7 @@ class Stream:
self.change_state(State.OPEN)
return None
def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -1966,29 +2007,29 @@ class Stream:
logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_start_command()
result = await self.local_endpoint.on_start_command()
if result is not None:
return result
self.change_state(State.STREAMING)
return None
def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_suspend_command()
result = await self.local_endpoint.on_suspend_command()
if result is not None:
return result
self.change_state(State.OPEN)
return None
def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_close_command()
result = await self.local_endpoint.on_close_command()
if result is not None:
return result
@@ -2003,7 +2044,8 @@ class Stream:
return None
def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
if self.rtp_channel is None:
# No need to wait
self.change_state(State.IDLE)
@@ -2028,7 +2070,6 @@ class Stream:
def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< stream channel closed', 'magenta'))
self.local_endpoint.on_rtp_channel_close()
self.local_endpoint.in_use = 0
self.rtp_channel = None
if self.state in (State.CLOSING, State.ABORTING):
@@ -2053,7 +2094,6 @@ class Stream:
self.state = State.IDLE
local_endpoint.stream = self
local_endpoint.in_use = 1
def __str__(self) -> str:
return (
@@ -2063,14 +2103,16 @@ class Stream:
# -----------------------------------------------------------------------------
@dataclass
class StreamEndPoint:
class StreamEndPoint(abc.ABC):
seid: int
media_type: MediaType
tsep: StreamEndPointType
in_use: int
capabilities: Iterable[ServiceCapabilities]
@property
def in_use(self) -> int:
raise NotImplementedError()
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
@@ -2110,14 +2152,30 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
in_use: int,
capabilities: Iterable[ServiceCapabilities],
) -> None:
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
StreamEndPointProxy.__init__(self, protocol, seid)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self._in_use = in_use
self.capabilities = capabilities
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
@property
def in_use(self) -> int:
return self._in_use
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
stream: Stream | None
@property
def in_use(self) -> int:
if self.stream and self.stream.state != State.IDLE:
return 1
return 0
EVENT_CONFIGURATION = "configuration"
EVENT_OPEN = "open"
EVENT_START = "start"
@@ -2140,8 +2198,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
capabilities: Iterable[ServiceCapabilities],
configuration: Iterable[ServiceCapabilities] | None = None,
):
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
utils.EventEmitter.__init__(self)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.capabilities = capabilities
self.protocol = protocol
self.configuration = configuration if configuration is not None else []
self.stream = None
@@ -2155,13 +2218,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
async def close(self) -> None:
"""[Source Only] Handles when receiving close command."""
def on_reconfigure_command(
async def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities]
) -> Message | None:
del command # unused.
return None
def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
logger.debug(
@@ -2172,34 +2235,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
self.emit(self.EVENT_CONFIGURATION)
return None
def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration)
def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN)
return None
def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START)
return None
def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND)
return None
def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE)
return None
def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT)
return None
def on_delayreport_command(self, delay: int) -> Message | None:
async def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay)
return None
def on_security_control_command(self, data: bytes) -> Message | None:
async def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data)
return None
@@ -2227,12 +2290,12 @@ class LocalSource(LocalStreamEndPoint):
codec_capabilities,
] + list(other_capabilities)
super().__init__(
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SRC,
capabilities=capabilities,
configuration=capabilities,
)
self.packet_pump = packet_pump
@@ -2251,13 +2314,13 @@ class LocalSource(LocalStreamEndPoint):
self.emit(self.EVENT_STOP)
@override
def on_start_command(self) -> Message | None:
asyncio.create_task(self.start())
async def on_start_command(self) -> Message | None:
await self.start()
return None
@override
def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop())
async def on_suspend_command(self) -> Message | None:
await self.stop()
return None
@@ -2271,11 +2334,11 @@ class LocalSink(LocalStreamEndPoint):
codec_capabilities,
]
super().__init__(
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SNK,
capabilities=capabilities,
)
def on_rtp_channel_open(self) -> None:
+9 -11
View File
@@ -2343,6 +2343,9 @@ class Device(utils.CompositeEventEmitter):
_pending_cis: dict[int, tuple[int, int]]
gatt_service: gatt_service.GenericAttributeProfileService | None = None
keystore: KeyStore | None = None
inquiry_response: bytes | None = None
address_resolver: smp.AddressResolver | None = None
connect_own_address_type: hci.OwnAddressType | None = None
EVENT_ADVERTISEMENT = "advertisement"
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
@@ -2461,17 +2464,12 @@ class Device(utils.CompositeEventEmitter):
self.bis_links = {}
self.big_syncs = {}
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.classic_pending_accepts = {
hci.Address.ANY: []
} # Futures, by BD address OR [Futures] for hci.Address.ANY
self._cis_lock = asyncio.Lock()
# Own address type cache
self.connect_own_address_type = None
self.name = config.name
self.public_address = hci.Address.ANY
self.random_address = config.address
@@ -5618,8 +5616,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
"""
@@ -5638,7 +5636,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
) -> None:
"""
Send a notification to all the subscribers of an attribute.
@@ -5657,8 +5655,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
):
"""
@@ -5679,7 +5677,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
):
"""
Send an indication to all the subscribers of an attribute.
+24 -22
View File
@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
+12 -1
View File
@@ -68,6 +68,8 @@ class HfpProtocolError(ProtocolError):
# -----------------------------------------------------------------------------
class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
@@ -84,10 +86,19 @@ class HfpProtocol:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
data = data.decode('utf-8', errors='replace')
logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
+324 -264
View File
@@ -44,6 +44,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
# prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -148,32 +154,6 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
@@ -190,279 +170,354 @@ class DataElement:
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@classmethod
def nil(cls) -> DataElement:
return cls(cls.NIL, None)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@classmethod
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@classmethod
def unsigned_integer_8(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@classmethod
def unsigned_integer_16(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@classmethod
def unsigned_integer_32(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@classmethod
def signed_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@classmethod
def signed_integer_8(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@classmethod
def signed_integer_16(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@classmethod
def signed_integer_32(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@classmethod
def uuid(cls, value: core.UUID) -> DataElement:
return cls(cls.UUID, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@classmethod
def text_string(cls, value: bytes) -> DataElement:
return cls(cls.TEXT_STRING, value)
@staticmethod
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@classmethod
def boolean(cls, value: bool) -> DataElement:
return cls(cls.BOOLEAN, value)
@staticmethod
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@classmethod
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.SEQUENCE, value)
@staticmethod
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@classmethod
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.ALTERNATIVE, value)
@staticmethod
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@classmethod
def url(cls, value: str) -> DataElement:
return cls(cls.URL, value)
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
@classmethod
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2:
return struct.unpack('>H', data)[0]
@classmethod
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 4:
return struct.unpack('>I', data)[0]
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 8:
return struct.unpack('>Q', data)[0]
@classmethod
def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
def __bytes__(self) -> bytes:
# Return early if we have a cache
if self._bytes:
return self._bytes
if self.type == DataElement.NIL:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
match self.type:
case DataElement.NIL:
data = b''
case DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
elif self.value_size == 4:
data = struct.pack('>I', self.value)
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
elif self.value_size == 2:
data = struct.pack('>h', self.value)
elif self.value_size == 4:
data = struct.pack('>i', self.value)
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
match self.value_size:
case 1:
data = struct.pack('B', self.value)
case 2:
data = struct.pack('>H', self.value)
case 4:
data = struct.pack('>I', self.value)
case 8:
data = struct.pack('>Q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.SIGNED_INTEGER:
match self.value_size:
case 1:
data = struct.pack('b', self.value)
case 2:
data = struct.pack('>h', self.value)
case 4:
data = struct.pack('>i', self.value)
case 8:
data = struct.pack('>q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.UUID:
data = bytes(self.value)[::-1]
case DataElement.URL:
data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
size = len(data)
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
match self.type:
case DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
case (
DataElement.UNSIGNED_INTEGER
| DataElement.SIGNED_INTEGER
| DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
case (
DataElement.TEXT_STRING
| DataElement.SEQUENCE
| DataElement.ALTERNATIVE
| DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty=False, indentation=0):
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
prefix = ' ' * indentation
type_name = self.type.name
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
match self.type:
case DataElement.NIL:
value_string = ''
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
case _:
if isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})'
def __str__(self):
def __str__(self) -> str:
return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
@@ -594,7 +649,10 @@ class SDP_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
if len(pdu) != 5 + parameters_length:
logger.warning("Expect %d bytes, got %d", 5 + parameters_length, len(pdu))
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
@@ -616,9 +674,11 @@ class SDP_PDU:
def __bytes__(self):
if self._payload is None:
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
parameters = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
self._payload = (
struct.pack('>BHH', self.pdu_id, self.transaction_id, len(parameters))
+ parameters
)
return self._payload
@property
+3
View File
@@ -36,6 +36,7 @@ from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
InvalidPacketError,
PhysicalTransport,
ProtocolError,
)
@@ -215,6 +216,8 @@ class SMP_Command:
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
if not pdu:
raise InvalidPacketError("Empty SMP PDU")
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)
+25
View File
@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'pdu',
(
b'', # empty PDU — would IndexError on pdu[0]
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
same DoS class as #912 (ATT empty PDU). The assembler is required to
log + drop and stay alive so the L2CAP channel survives."""
completed = []
def callback(transaction_label, message):
completed.append((transaction_label, message))
assembler = avdtp.MessageAssembler(callback)
# Must not raise; nothing should be delivered to callback either.
assembler.on_pdu(pdu)
assert not completed
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex(
+52
View File
@@ -18,9 +18,11 @@
import asyncio
import logging
import os
import re
import pytest
from bumble import sdp
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
@@ -206,6 +208,16 @@ def sdp_records(record_count=1):
}
# -----------------------------------------------------------------------------
def test_pdu_parameter_length(caplog) -> None:
caplog.set_level(logging.WARNING)
pdu = sdp.SDP_ErrorResponse(
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
)
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search():
@@ -428,3 +440,43 @@ async def run():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
# -----------------------------------------------------------------------------
def test_nested_sequence_recursion_guard():
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
the parser with RecursionError. Instead a ValueError is raised once the
configured nesting limit is exceeded.
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
without a depth limit. A malicious SDP peer could craft a PDU exceeding
Pythons default recursion limit (~1000 frames) and crash the host.
"""
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
inner = b"\x35\x00" # empty SEQUENCE terminator
for _ in range(1500):
size = len(inner)
if size >= 65535:
break
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
with pytest.raises(ValueError, match="nesting exceeds max depth"):
DataElement.from_bytes(inner)
def test_nested_sequence_within_limit_still_works():
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
leaf = DataElement.unsigned_integer(1, value_size=2)
payload = leaf
for _ in range(16): # under the 32-depth limit
payload = DataElement.sequence([payload])
raw = bytes(payload)
parsed = DataElement.from_bytes(raw)
# Walk back down to confirm structural integrity preserved
cur = parsed
for _ in range(16):
assert cur.type == DataElement.SEQUENCE
cur = cur.value[0]
assert cur.type == DataElement.UNSIGNED_INTEGER
assert cur.value == 1