Compare commits

..

22 Commits

Author SHA1 Message Date
Josh Wu
8988a85245 Merge pull request #919 from zxzxwu/sdp
SDP: Move parser functions to parser class
2026-04-29 13:21:13 +08:00
Josh Wu
0813da2278 SDP: Move parser functions to parser class 2026-04-28 13:27:50 +08:00
Gilles Boccon-Gibod
a1ff183d44 Merge pull request #915 from dlech/notify-subscribers-type-hints
improve type hints for notify/indicate subscriber(s) methods
2026-04-27 21:45:38 +02:00
Gilles Boccon-Gibod
7adf44eddf Merge pull request #916 from dlech/fix-crash-in-attribute-repr
fix crash in `bumble.att.Attribute.__repr__`
2026-04-27 21:41:41 +02:00
Josh Wu
05accbf805 Merge pull request #918 from ibondarenko1/fix/avdtp-empty-pdu-guard
avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
2026-04-27 10:01:51 +08:00
Josh Wu
80f54f2a09 Merge pull request #917 from dlech/fix-regex-with-backslash
Fix regex syntax warning in sdp_test.py.
2026-04-27 09:55:36 +08:00
ibondarenko1
07b5e33e09 avdtp: address review nits — use truthy checks
Per @zxzxwu review on #918:
- bumble/avdtp.py: replace `if len(pdu) < 1:` with `if not pdu:`
- tests/avdtp_test.py: replace `assert completed == []` with
  `assert not completed`

Both are idiomatic Python truthy checks; behavior identical.
2026-04-26 18:49:55 -07:00
ibondarenko1
b874e26a4f avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
A remote peer can send an AVDTP frame shorter than the assembler expects.
The current MessageAssembler.on_pdu() unconditionally accesses pdu[0],
pdu[1], and (for START packets) pdu[2], so a 0-, 1-, or 2-byte frame
raises IndexError. The exception propagates up through L2CAP's read loop
and tears down the channel — same DoS class as #912 (empty ATT PDU) and
#914 (unbounded SDP recursion).

Fix: validate length before each access. Empty PDUs and packets shorter
than the type-specific minimum are logged and dropped; the assembler
stays alive so the L2CAP channel is not torn down.

- bumble/avdtp.py: length guards in MessageAssembler.on_pdu before
  accessing pdu[0], pdu[1], pdu[2].
- tests/avdtp_test.py: regression test covering empty PDU, 1-byte SINGLE,
  1-byte START, 2-byte START — all four would have raised IndexError
  pre-fix; assembler now drops without raising.
2026-04-26 18:16:15 -07:00
David Lechner
baa5257780 improve type hints for notify/indicate subscriber(s) methods
Pyright expects generic type parameters to be specified for the
Attribute class, otherwise it treats the type as Unknown which can
trigger reportUnknownMemberType errors.

This can be solved by using a generic type parameter for these methods
which also has the benefit of making sure that the value parameter has
the correct type for the attribute.

In some cases, a new local `value_as_bytes` variable is needed to avoid
type errors and makes the code less confusing by not overwriting the
original `value` variable.
2026-04-26 09:43:40 -05:00
David Lechner
a91ea9110c Fix regex syntax warning in sdp_test.py.
Change regex match string to raw string to avoid syntax warning:

    tests/sdp_test.py:218: SyntaxWarning: invalid escape sequence '\d'
    assert not re.search("Expect \d+ bytes, got \d+", caplog.text)

In the future, this will become an error, so we should fix it now.
2026-04-26 09:31:18 -05:00
Josh Wu
1686c5b11b Merge pull request #914 from ibondarenko1/fix/sdp-recursion-depth-limit
sdp: bound DataElement parse recursion to prevent RecursionError DoS
2026-04-26 17:22:59 +08:00
David Lechner
d9481992bb fix crash in bumble.att.Attribute.__repr__
If an attribute does not contains a bytes value, it would crash with
something like:

    AttributeError: 'NoneType' object has no attribute 'hex'

Clearly, the intention here was to use `value_str` to avoid this
possibility.
2026-04-25 17:01:25 -05:00
ibondarenko1
16d0ed56cf sdp: address review nits (import at top, InvalidPacketError)
- bumble/sdp.py: replace raise ValueError with raise InvalidPacketError
  in DataElement.list_from_bytes depth guard. InvalidPacketError
  already imported at line 34 and extends ValueError so the existing
  regression test continues to match.
- tests/sdp_test.py: remove duplicate 'import pytest' inside
  test_nested_sequence_recursion_guard; pytest already imported at
  module top (line 23).

Threading.local counter left as-is per zxzxwu's 'leave it here and
refactor later' comment on the PR.
2026-04-24 11:42:49 -07:00
Ievgen Bondarenko
c55eb156b8 sdp: fix lint formatting (black: blank line after import pytest) 2026-04-24 00:06:56 -07:00
ibondarenko1
8614881fb3 sdp: bound DataElement parse recursion to prevent RecursionError DoS
DataElement.from_bytes -> list_from_bytes -> (SEQUENCE/ALTERNATIVE
constructor dispatches back to list_from_bytes) had no depth limit. A
malicious SDP peer could send a PDU of a few kilobytes containing ~1000
nested SEQUENCE tags and exhaust the Python recursion stack, crashing the
host with an unhandled RecursionError propagating out of the SDP handler.

Reachable via: any remote Bluetooth device that Bumble performs SDP
service discovery against (default during Classic connection setup).

Same family as PR #912 (ATT_PDU.from_bytes empty PDU IndexError) - remote
unchecked-input parser crash in the Bluetooth stack.

Fix: thread-local depth counter, cap nesting at 32 (well above anything a
legitimate service record uses). Added two regression tests covering the
deep-nesting reject path and normal 16-level-nested SEQUENCE parsing.

Reproducer (4.5 KB payload, deterministic crash on 0.0.228):

    from bumble.sdp import DataElement
    inner = b"\x35\x00"
    for _ in range(1500):
        size = len(inner)
        if size < 65535:
            inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
    DataElement.from_bytes(inner)  # RecursionError before fix

Signed-off-by: ibondarenko1 <ibondarenko1@users.noreply.github.com>
2026-04-23 00:53:06 -07:00
Josh Wu
27d02ef18d Merge pull request #913 from zxzxwu/sdp
SDP: Fix wrong parameter size
2026-04-20 16:32:37 +08:00
Josh Wu
c0725e2a4a SDP: Fix wrong parameter size 2026-04-20 16:23:19 +08:00
Josh Wu
bf0784dde4 Merge pull request #912 from ibondarenko1/fix/empty-pdu-crash
fix: add input validation to prevent remote crash from empty/malforme…
2026-04-20 14:36:48 +08:00
Ievgen Bondarenko
444f43f6a3 fix: address review feedback - use InvalidPacketError and abort on buffer overflow
- att.py: raise core.InvalidPacketError instead of generic ValueError
- smp.py: raise core.InvalidPacketError instead of generic ValueError
- hfp.py: add MAX_BUFFER_SIZE class constant (64KB)
- hfp.py: drop incoming data when it would overflow buffer instead of
  truncating, preserving existing partial-packet state

Per review comments on PR #912 by @zxzxwu.
2026-04-16 11:24:09 -07:00
Gilles Boccon-Gibod
2420c47cf1 Merge pull request #911 from google/gbg/issue-910
release command semaphore after timeout
2026-04-16 18:11:57 +02:00
Ievgen Bondarenko
0a78e7506b fix: add input validation to prevent remote crash from empty/malformed PDUs
Add length checks in from_bytes() for ATT and SMP protocol parsers
to prevent IndexError crashes from empty PDUs sent by remote Bluetooth
devices. Also add buffer size limit and UTF-8 error handling in HFP
protocol to prevent memory exhaustion and decode crashes.

- bumble/att.py: validate PDU is non-empty before accessing pdu[0]
- bumble/smp.py: validate PDU is non-empty before accessing pdu[0]
- bumble/hfp.py: limit buffer to 64KB, handle invalid UTF-8 gracefully

These issues can be triggered by a remote Bluetooth device sending
malformed packets, causing denial of service on the host.
2026-04-16 01:43:41 -07:00
Gilles Boccon-Gibod
f7cc6f6657 release command semaphore after timeout 2026-04-15 16:54:54 +02:00
11 changed files with 497 additions and 302 deletions

View File

@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
from bumble.hci import HCI_Object
# -----------------------------------------------------------------------------
@@ -249,6 +249,8 @@ class ATT_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
if not pdu:
raise InvalidPacketError("Empty ATT PDU")
op_code = pdu[0]
subclass = ATT_PDU.pdu_classes.get(op_code)
@@ -1081,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}'
value_string = f', value={value_str}'
else:
value_string = ''
return (

View File

@@ -311,6 +311,13 @@ class MessageAssembler:
def on_pdu(self, pdu: bytes) -> None:
self.packet_count += 1
# Drop empty PDUs sent by remote — accessing pdu[0] below would
# raise IndexError, propagating up to the L2CAP read loop and
# tearing down the channel. Same class as #912 (ATT empty PDU).
if not pdu:
logger.warning('AVDTP message assembler: empty PDU dropped')
return
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
message_type = Message.MessageType(pdu[0] & 3)
@@ -324,6 +331,23 @@ class MessageAssembler:
Protocol.PacketType.SINGLE_PACKET,
Protocol.PacketType.START_PACKET,
):
# Both single and start packets carry the signal identifier in
# pdu[1]; start packets additionally carry the packet count in
# pdu[2]. Guard each access so a malformed remote frame can't
# crash the message assembler.
if len(pdu) < 2:
logger.warning(
'AVDTP %s packet too short (%d bytes); dropped',
packet_type.name,
len(pdu),
)
return
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
logger.warning(
'AVDTP START packet missing signal-packet count; dropped'
)
return
if self.message is not None:
# The previous message has not been terminated
logger.warning(

View File

@@ -5618,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
"""
@@ -5638,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
) -> None:
"""
Send a notification to all the subscribers of an attribute.
@@ -5657,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
):
"""
@@ -5679,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
):
"""
Send an indication to all the subscribers of an attribute.

View File

@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)

View File

@@ -68,6 +68,8 @@ class HfpProtocolError(ProtocolError):
# -----------------------------------------------------------------------------
class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
@@ -84,10 +86,19 @@ class HfpProtocol:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
data = data.decode('utf-8', errors='replace')
logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:

View File

@@ -692,10 +692,8 @@ class Host(utils.EventEmitter):
finally:
self.pending_command = None
self.pending_response = None
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
if response is None or (
response.num_hci_command_packets and self.command_semaphore.locked()
):
self.command_semaphore.release()

View File

@@ -44,6 +44,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
# prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -148,32 +154,6 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
@@ -190,279 +170,354 @@ class DataElement:
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@classmethod
def nil(cls) -> DataElement:
return cls(cls.NIL, None)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@classmethod
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@classmethod
def unsigned_integer_8(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@classmethod
def unsigned_integer_16(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@classmethod
def unsigned_integer_32(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@classmethod
def signed_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@classmethod
def signed_integer_8(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@classmethod
def signed_integer_16(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@classmethod
def signed_integer_32(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@classmethod
def uuid(cls, value: core.UUID) -> DataElement:
return cls(cls.UUID, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@classmethod
def text_string(cls, value: bytes) -> DataElement:
return cls(cls.TEXT_STRING, value)
@staticmethod
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@classmethod
def boolean(cls, value: bool) -> DataElement:
return cls(cls.BOOLEAN, value)
@staticmethod
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@classmethod
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.SEQUENCE, value)
@staticmethod
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@classmethod
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.ALTERNATIVE, value)
@staticmethod
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@classmethod
def url(cls, value: str) -> DataElement:
return cls(cls.URL, value)
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
@classmethod
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2:
return struct.unpack('>H', data)[0]
@classmethod
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 4:
return struct.unpack('>I', data)[0]
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 8:
return struct.unpack('>Q', data)[0]
@classmethod
def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
def __bytes__(self) -> bytes:
# Return early if we have a cache
if self._bytes:
return self._bytes
if self.type == DataElement.NIL:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
match self.type:
case DataElement.NIL:
data = b''
case DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
elif self.value_size == 4:
data = struct.pack('>I', self.value)
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
elif self.value_size == 2:
data = struct.pack('>h', self.value)
elif self.value_size == 4:
data = struct.pack('>i', self.value)
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
match self.value_size:
case 1:
data = struct.pack('B', self.value)
case 2:
data = struct.pack('>H', self.value)
case 4:
data = struct.pack('>I', self.value)
case 8:
data = struct.pack('>Q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.SIGNED_INTEGER:
match self.value_size:
case 1:
data = struct.pack('b', self.value)
case 2:
data = struct.pack('>h', self.value)
case 4:
data = struct.pack('>i', self.value)
case 8:
data = struct.pack('>q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.UUID:
data = bytes(self.value)[::-1]
case DataElement.URL:
data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
size = len(data)
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
match self.type:
case DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
case (
DataElement.UNSIGNED_INTEGER
| DataElement.SIGNED_INTEGER
| DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
case (
DataElement.TEXT_STRING
| DataElement.SEQUENCE
| DataElement.ALTERNATIVE
| DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty=False, indentation=0):
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
prefix = ' ' * indentation
type_name = self.type.name
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
match self.type:
case DataElement.NIL:
value_string = ''
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
case _:
if isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})'
def __str__(self):
def __str__(self) -> str:
return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
@@ -594,7 +649,10 @@ class SDP_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
if len(pdu) != 5 + parameters_length:
logger.warning("Expect %d bytes, got %d", 5 + parameters_length, len(pdu))
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
@@ -616,9 +674,11 @@ class SDP_PDU:
def __bytes__(self):
if self._payload is None:
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
parameters = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
self._payload = (
struct.pack('>BHH', self.pdu_id, self.transaction_id, len(parameters))
+ parameters
)
return self._payload
@property

View File

@@ -36,6 +36,7 @@ from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
InvalidPacketError,
PhysicalTransport,
ProtocolError,
)
@@ -215,6 +216,8 @@ class SMP_Command:
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
if not pdu:
raise InvalidPacketError("Empty SMP PDU")
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)

View File

@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'pdu',
(
b'', # empty PDU — would IndexError on pdu[0]
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
same DoS class as #912 (ATT empty PDU). The assembler is required to
log + drop and stay alive so the L2CAP channel survives."""
completed = []
def callback(transaction_label, message):
completed.append((transaction_label, message))
assembler = avdtp.MessageAssembler(callback)
# Must not raise; nothing should be delivered to callback either.
assembler.on_pdu(pdu)
assert not completed
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex(

View File

@@ -171,14 +171,15 @@ class Source:
class Sink:
response: HCI_Event
response: HCI_Event | None
def __init__(self, source: Source, response: HCI_Event) -> None:
def __init__(self, source: Source, response: HCI_Event | None) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
self.source.sink.on_packet(bytes(self.response))
if self.response is not None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
@@ -228,6 +229,23 @@ async def test_send_sync_command() -> None:
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
source = Source()
sink = Sink(source, None)
host = Host(source, sink)
host.ready = True
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
# The sending semaphore should have been released, so this should not block
# indefinitely
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()

View File

@@ -18,9 +18,11 @@
import asyncio
import logging
import os
import re
import pytest
from bumble import sdp
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
@@ -206,6 +208,16 @@ def sdp_records(record_count=1):
}
# -----------------------------------------------------------------------------
def test_pdu_parameter_length(caplog) -> None:
caplog.set_level(logging.WARNING)
pdu = sdp.SDP_ErrorResponse(
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
)
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search():
@@ -428,3 +440,43 @@ async def run():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
# -----------------------------------------------------------------------------
def test_nested_sequence_recursion_guard():
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
the parser with RecursionError. Instead a ValueError is raised once the
configured nesting limit is exceeded.
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
without a depth limit. A malicious SDP peer could craft a PDU exceeding
Pythons default recursion limit (~1000 frames) and crash the host.
"""
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
inner = b"\x35\x00" # empty SEQUENCE terminator
for _ in range(1500):
size = len(inner)
if size >= 65535:
break
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
with pytest.raises(ValueError, match="nesting exceeds max depth"):
DataElement.from_bytes(inner)
def test_nested_sequence_within_limit_still_works():
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
leaf = DataElement.unsigned_integer(1, value_size=2)
payload = leaf
for _ in range(16): # under the 32-depth limit
payload = DataElement.sequence([payload])
raw = bytes(payload)
parsed = DataElement.from_bytes(raw)
# Walk back down to confirm structural integrity preserved
cur = parsed
for _ in range(16):
assert cur.type == DataElement.SEQUENCE
cur = cur.value[0]
assert cur.type == DataElement.UNSIGNED_INTEGER
assert cur.value == 1