From 0813da2278439feaf29fde88a835bab71a0e5565 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 27 Apr 2026 10:54:32 +0800 Subject: [PATCH] SDP: Move parser functions to parser class --- bumble/sdp.py | 581 +++++++++++++++++++++++++++----------------------- 1 file changed, 309 insertions(+), 272 deletions(-) diff --git a/bumble/sdp.py b/bumble/sdp.py index 6ffce32..f0712dd 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -20,7 +20,6 @@ from __future__ import annotations import asyncio import logging import struct -import threading from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar @@ -49,7 +48,6 @@ logger = logging.getLogger(__name__) # prevent a malicious peer from crashing the process via a deeply nested PDU. # 32 levels is well beyond anything a legitimate service record uses. _MAX_DATA_ELEMENT_NESTING = 32 -_parse_state = threading.local() # ----------------------------------------------------------------------------- @@ -156,32 +154,6 @@ class DataElement: ALTERNATIVE = Type.ALTERNATIVE URL = Type.URL - TYPE_CONSTRUCTORS = { - NIL: lambda x: DataElement(DataElement.NIL, None), - UNSIGNED_INTEGER: lambda x, y: DataElement( - DataElement.UNSIGNED_INTEGER, - DataElement.unsigned_integer_from_bytes(x), - value_size=y, - ), - SIGNED_INTEGER: lambda x, y: DataElement( - DataElement.SIGNED_INTEGER, - DataElement.signed_integer_from_bytes(x), - value_size=y, - ), - UUID: lambda x: DataElement( - DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x))) - ), - TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x), - BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), - SEQUENCE: lambda x: DataElement( - DataElement.SEQUENCE, DataElement.list_from_bytes(x) - ), - ALTERNATIVE: lambda x: DataElement( - DataElement.ALTERNATIVE, DataElement.list_from_bytes(x) - ), - URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')), - } - type: Type value: Any value_size: int | None = None @@ -198,289 +170,354 @@ class DataElement: 'integer types must have a value size specified' ) - @staticmethod - def nil() -> DataElement: - return DataElement(DataElement.NIL, None) + @classmethod + def nil(cls) -> DataElement: + return cls(cls.NIL, None) - @staticmethod - def unsigned_integer(value: int, value_size: int) -> DataElement: - return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size) + @classmethod + def unsigned_integer(cls, value: int, value_size: int) -> DataElement: + return cls(cls.UNSIGNED_INTEGER, value, value_size) - @staticmethod - def unsigned_integer_8(value: int) -> DataElement: - return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1) + @classmethod + def unsigned_integer_8(cls, value: int) -> DataElement: + return cls(cls.UNSIGNED_INTEGER, value, value_size=1) - @staticmethod - def unsigned_integer_16(value: int) -> DataElement: - return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2) + @classmethod + def unsigned_integer_16(cls, value: int) -> DataElement: + return cls(cls.UNSIGNED_INTEGER, value, value_size=2) - @staticmethod - def unsigned_integer_32(value: int) -> DataElement: - return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4) + @classmethod + def unsigned_integer_32(cls, value: int) -> DataElement: + return cls(cls.UNSIGNED_INTEGER, value, value_size=4) - @staticmethod - def signed_integer(value: int, value_size: int) -> DataElement: - return DataElement(DataElement.SIGNED_INTEGER, value, value_size) + @classmethod + def signed_integer(cls, value: int, value_size: int) -> DataElement: + return cls(cls.SIGNED_INTEGER, value, value_size) - @staticmethod - def signed_integer_8(value: int) -> DataElement: - return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1) + @classmethod + def signed_integer_8(cls, value: int) -> DataElement: + return cls(cls.SIGNED_INTEGER, value, value_size=1) - @staticmethod - def signed_integer_16(value: int) -> DataElement: - return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2) + @classmethod + def signed_integer_16(cls, value: int) -> DataElement: + return cls(cls.SIGNED_INTEGER, value, value_size=2) - @staticmethod - def signed_integer_32(value: int) -> DataElement: - return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4) + @classmethod + def signed_integer_32(cls, value: int) -> DataElement: + return cls(cls.SIGNED_INTEGER, value, value_size=4) - @staticmethod - def uuid(value: core.UUID) -> DataElement: - return DataElement(DataElement.UUID, value) + @classmethod + def uuid(cls, value: core.UUID) -> DataElement: + return cls(cls.UUID, value) - @staticmethod - def text_string(value: bytes) -> DataElement: - return DataElement(DataElement.TEXT_STRING, value) + @classmethod + def text_string(cls, value: bytes) -> DataElement: + return cls(cls.TEXT_STRING, value) - @staticmethod - def boolean(value: bool) -> DataElement: - return DataElement(DataElement.BOOLEAN, value) + @classmethod + def boolean(cls, value: bool) -> DataElement: + return cls(cls.BOOLEAN, value) - @staticmethod - def sequence(value: Iterable[DataElement]) -> DataElement: - return DataElement(DataElement.SEQUENCE, value) + @classmethod + def sequence(cls, value: Iterable[DataElement]) -> DataElement: + return cls(cls.SEQUENCE, value) - @staticmethod - def alternative(value: Iterable[DataElement]) -> DataElement: - return DataElement(DataElement.ALTERNATIVE, value) + @classmethod + def alternative(cls, value: Iterable[DataElement]) -> DataElement: + return cls(cls.ALTERNATIVE, value) - @staticmethod - def url(value: str) -> DataElement: - return DataElement(DataElement.URL, value) + @classmethod + def url(cls, value: str) -> DataElement: + return cls(cls.URL, value) - @staticmethod - def unsigned_integer_from_bytes(data): - if len(data) == 1: - return data[0] + @classmethod + def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int: + match length: + case 1: + return data[offset] + case 2: + return struct.unpack_from('>H', data, offset)[0] + case 4: + return struct.unpack_from('>I', data, offset)[0] + case 8: + return struct.unpack_from('>Q', data, offset)[0] + case invalid_length: + raise InvalidPacketError(f'invalid integer length {invalid_length}') - if len(data) == 2: - return struct.unpack('>H', data)[0] + @classmethod + def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int: + match length: + case 1: + return struct.unpack_from('b', data, offset)[0] + case 2: + return struct.unpack_from('>h', data, offset)[0] + case 4: + return struct.unpack_from('>i', data, offset)[0] + case 8: + return struct.unpack_from('>q', data, offset)[0] + case invalid_length: + raise InvalidPacketError(f'invalid integer length {invalid_length}') - if len(data) == 4: - return struct.unpack('>I', data)[0] + @classmethod + def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]: + parser = DataElementParser(data, offset) + element = parser.parse_next() + return parser.offset, element - if len(data) == 8: - return struct.unpack('>Q', data)[0] + @classmethod + def from_bytes(cls, data: bytes) -> DataElement: + return DataElementParser(data).parse_next() - raise InvalidPacketError(f'invalid integer length {len(data)}') - - @staticmethod - def signed_integer_from_bytes(data): - if len(data) == 1: - return struct.unpack('b', data)[0] - - if len(data) == 2: - return struct.unpack('>h', data)[0] - - if len(data) == 4: - return struct.unpack('>i', data)[0] - - if len(data) == 8: - return struct.unpack('>q', data)[0] - - raise InvalidPacketError(f'invalid integer length {len(data)}') - - @staticmethod - def list_from_bytes(data): - depth = getattr(_parse_state, "depth", 0) - if depth >= _MAX_DATA_ELEMENT_NESTING: - raise InvalidPacketError( - f"SDP data element nesting exceeds max depth " - f"({_MAX_DATA_ELEMENT_NESTING})" - ) - _parse_state.depth = depth + 1 - try: - elements = [] - while data: - element = DataElement.from_bytes(data) - elements.append(element) - data = data[len(bytes(element)) :] - return elements - finally: - _parse_state.depth = depth - - @staticmethod - def parse_from_bytes(data, offset): - element = DataElement.from_bytes(data[offset:]) - return offset + len(bytes(element)), element - - @staticmethod - def from_bytes(data): - element_type = data[0] >> 3 - size_index = data[0] & 7 - value_offset = 0 - if size_index == 0: - if element_type == DataElement.NIL: - value_size = 0 - else: - value_size = 1 - elif size_index == 1: - value_size = 2 - elif size_index == 2: - value_size = 4 - elif size_index == 3: - value_size = 8 - elif size_index == 4: - value_size = 16 - elif size_index == 5: - value_size = data[1] - value_offset = 1 - elif size_index == 6: - value_size = struct.unpack('>H', data[1:3])[0] - value_offset = 2 - else: # size_index == 7 - value_size = struct.unpack('>I', data[1:5])[0] - value_offset = 4 - - value_data = data[1 + value_offset : 1 + value_offset + value_size] - constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type) - if constructor: - if element_type in ( - DataElement.UNSIGNED_INTEGER, - DataElement.SIGNED_INTEGER, - ): - result = constructor(value_data, value_size) - else: - result = constructor(value_data) - else: - result = DataElement(element_type, value_data) - result._bytes = data[ - : 1 + value_offset + value_size - ] # Keep a copy so we can re-serialize to an exact replica - return result - - def __bytes__(self): + def __bytes__(self) -> bytes: # Return early if we have a cache if self._bytes: return self._bytes - if self.type == DataElement.NIL: - data = b'' - elif self.type == DataElement.UNSIGNED_INTEGER: - if self.value < 0: - raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative') + match self.type: + case DataElement.NIL: + data = b'' + case DataElement.UNSIGNED_INTEGER: + if self.value < 0: + raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative') - if self.value_size == 1: - data = struct.pack('B', self.value) - elif self.value_size == 2: - data = struct.pack('>H', self.value) - elif self.value_size == 4: - data = struct.pack('>I', self.value) - elif self.value_size == 8: - data = struct.pack('>Q', self.value) - else: - raise InvalidArgumentError('invalid value_size') - elif self.type == DataElement.SIGNED_INTEGER: - if self.value_size == 1: - data = struct.pack('b', self.value) - elif self.value_size == 2: - data = struct.pack('>h', self.value) - elif self.value_size == 4: - data = struct.pack('>i', self.value) - elif self.value_size == 8: - data = struct.pack('>q', self.value) - else: - raise InvalidArgumentError('invalid value_size') - elif self.type == DataElement.UUID: - data = bytes(reversed(bytes(self.value))) - elif self.type == DataElement.URL: - data = self.value.encode('utf8') - elif self.type == DataElement.BOOLEAN: - data = bytes([1 if self.value else 0]) - elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): - data = b''.join([bytes(element) for element in self.value]) - else: - data = self.value + match self.value_size: + case 1: + data = struct.pack('B', self.value) + case 2: + data = struct.pack('>H', self.value) + case 4: + data = struct.pack('>I', self.value) + case 8: + data = struct.pack('>Q', self.value) + case invalid_length: + raise InvalidArgumentError( + f'invalid value_size of {invalid_length}' + ) + case DataElement.SIGNED_INTEGER: + match self.value_size: + case 1: + data = struct.pack('b', self.value) + case 2: + data = struct.pack('>h', self.value) + case 4: + data = struct.pack('>i', self.value) + case 8: + data = struct.pack('>q', self.value) + case invalid_length: + raise InvalidArgumentError( + f'invalid value_size of {invalid_length}' + ) + case DataElement.UUID: + data = bytes(self.value)[::-1] + case DataElement.URL: + data = self.value.encode('utf8') + case DataElement.BOOLEAN: + data = bytes([1 if self.value else 0]) + case DataElement.SEQUENCE | DataElement.ALTERNATIVE: + data = b''.join([bytes(element) for element in self.value]) + case _: + data = self.value size = len(data) size_bytes = b'' - if self.type == DataElement.NIL: - if size != 0: - raise InvalidArgumentError('NIL must be empty') - size_index = 0 - elif self.type in ( - DataElement.UNSIGNED_INTEGER, - DataElement.SIGNED_INTEGER, - DataElement.UUID, - ): - if size <= 1: + match self.type: + case DataElement.NIL: + if size != 0: + raise InvalidArgumentError('NIL must be empty') size_index = 0 - elif size == 2: - size_index = 1 - elif size == 4: - size_index = 2 - elif size == 8: - size_index = 3 - elif size == 16: - size_index = 4 - else: - raise InvalidArgumentError('invalid data size') - elif self.type in ( - DataElement.TEXT_STRING, - DataElement.SEQUENCE, - DataElement.ALTERNATIVE, - DataElement.URL, - ): - if size <= 0xFF: - size_index = 5 - size_bytes = bytes([size]) - elif size <= 0xFFFF: - size_index = 6 - size_bytes = struct.pack('>H', size) - elif size <= 0xFFFFFFFF: - size_index = 7 - size_bytes = struct.pack('>I', size) - else: - raise InvalidArgumentError('invalid data size') - elif self.type == DataElement.BOOLEAN: - if size != 1: - raise InvalidArgumentError('boolean must be 1 byte') - size_index = 0 - else: - raise RuntimeError("internal error - self.type not supported") + case ( + DataElement.UNSIGNED_INTEGER + | DataElement.SIGNED_INTEGER + | DataElement.UUID + ): + if size <= 1: + size_index = 0 + elif size == 2: + size_index = 1 + elif size == 4: + size_index = 2 + elif size == 8: + size_index = 3 + elif size == 16: + size_index = 4 + else: + raise InvalidArgumentError('invalid data size') + case ( + DataElement.TEXT_STRING + | DataElement.SEQUENCE + | DataElement.ALTERNATIVE + | DataElement.URL + ): + if size <= 0xFF: + size_index = 5 + size_bytes = bytes([size]) + elif size <= 0xFFFF: + size_index = 6 + size_bytes = struct.pack('>H', size) + elif size <= 0xFFFFFFFF: + size_index = 7 + size_bytes = struct.pack('>I', size) + else: + raise InvalidArgumentError('invalid data size') + case DataElement.BOOLEAN: + if size != 1: + raise InvalidArgumentError('boolean must be 1 byte') + size_index = 0 + case unsupported_type: + raise core.InvalidPacketError( + f"internal error - {unsupported_type} not supported" + ) self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data return self._bytes - def to_string(self, pretty=False, indentation=0): + def to_string(self, pretty: bool = False, indentation: int = 0) -> str: prefix = ' ' * indentation type_name = self.type.name - if self.type == DataElement.NIL: - value_string = '' - elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): - container_separator = '\n' if pretty else '' - element_separator = '\n' if pretty else ',' - elements = [ - element.to_string(pretty, indentation + 1 if pretty else 0) - for element in self.value - ] - value_string = ( - f'[{container_separator}' - f'{element_separator.join(elements)}' - f'{container_separator}{prefix}]' - ) - elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): - value_string = f'{self.value}#{self.value_size}' - elif isinstance(self.value, DataElement): - value_string = self.value.to_string(pretty, indentation) - else: - value_string = str(self.value) + match self.type: + case DataElement.NIL: + value_string = '' + case DataElement.SEQUENCE | DataElement.ALTERNATIVE: + container_separator = '\n' if pretty else '' + element_separator = '\n' if pretty else ',' + elements = [ + element.to_string(pretty, indentation + 1 if pretty else 0) + for element in self.value + ] + value_string = ( + f'[{container_separator}' + f'{element_separator.join(elements)}' + f'{container_separator}{prefix}]' + ) + case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER: + value_string = f'{self.value}#{self.value_size}' + case _: + if isinstance(self.value, DataElement): + value_string = self.value.to_string(pretty, indentation) + else: + value_string = str(self.value) return f'{prefix}{type_name}({value_string})' - def __str__(self): + def __str__(self) -> str: return self.to_string() +class DataElementParser: + def __init__( + self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING + ) -> None: + self.data = data + self.offset = offset + self.depth = 0 + self.max_depth = max_depth + + def parse_next(self) -> DataElement: + if self.offset >= len(self.data): + raise core.InvalidStateError( + f"offset {self.offset} exceeds len(data) {len(self.data)}" + ) + start_offset = self.offset + element_type = DataElement.Type(self.data[self.offset] >> 3) + size_index = self.data[self.offset] & 7 + self.offset += 1 + + value_size: int + match size_index: + case 0: + if element_type == DataElement.NIL: + value_size = 0 + else: + value_size = 1 + case 1: + value_size = 2 + case 2: + value_size = 4 + case 3: + value_size = 8 + case 4: + value_size = 16 + case 5: + value_size = self.data[self.offset] + self.offset += 1 + case 6: + value_size = struct.unpack_from('>H', self.data, self.offset)[0] + self.offset += 2 + case 7: + value_size = struct.unpack_from('>I', self.data, self.offset)[0] + self.offset += 4 + case _: + raise core.UnreachableError() + + value_start = self.offset + value_end = self.offset + value_size + + match element_type: + case DataElement.NIL: + result = DataElement(DataElement.NIL, None) + case DataElement.UNSIGNED_INTEGER: + result = DataElement( + DataElement.UNSIGNED_INTEGER, + DataElement.unsigned_integer_from_bytes( + self.data, value_start, value_size + ), + value_size=value_size, + ) + case DataElement.SIGNED_INTEGER: + result = DataElement( + DataElement.SIGNED_INTEGER, + DataElement.signed_integer_from_bytes( + self.data, value_start, value_size + ), + value_size=value_size, + ) + case DataElement.UUID: + result = DataElement( + DataElement.UUID, + core.UUID.from_bytes(self.data[value_start:value_end][::-1]), + ) + case DataElement.TEXT_STRING: + result = DataElement( + DataElement.TEXT_STRING, self.data[value_start:value_end] + ) + case DataElement.BOOLEAN: + result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1) + case DataElement.SEQUENCE | DataElement.ALTERNATIVE: + self.offset = value_start + result = DataElement( + element_type, + self._list_from_bytes(value_end), + ) + if self.offset != value_end: + logger.warning( + "Expect parsing until offset %d, but ends at %d", + value_end, + self.offset, + ) + case DataElement.URL: + result = DataElement( + DataElement.URL, self.data[value_start:value_end].decode('utf8') + ) + case other_type: + result = DataElement(other_type, self.data[value_start:value_end]) + + self.offset = value_end + result._bytes = self.data[start_offset:value_end] + + return result + + def _list_from_bytes(self, end_offset: int) -> list[DataElement]: + if self.depth >= self.max_depth: + raise InvalidPacketError( + f"SDP data element nesting exceeds max depth " f"({self.max_depth})" + ) + self.depth += 1 + elements = [] + while self.offset < end_offset: + elements.append(self.parse_next()) + self.depth -= 1 + return elements + + # ----------------------------------------------------------------------------- @dataclass class ServiceAttribute: