SDP: Move parser functions to parser class

This commit is contained in:
Josh Wu
2026-04-27 10:54:32 +08:00
parent 05accbf805
commit 0813da2278

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import asyncio
import logging
import struct
import threading
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar
@@ -49,7 +48,6 @@ logger = logging.getLogger(__name__)
# prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32
_parse_state = threading.local()
# -----------------------------------------------------------------------------
@@ -156,32 +154,6 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
@@ -198,289 +170,354 @@ class DataElement:
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@classmethod
def nil(cls) -> DataElement:
return cls(cls.NIL, None)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@classmethod
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@classmethod
def unsigned_integer_8(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@classmethod
def unsigned_integer_16(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@classmethod
def unsigned_integer_32(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@classmethod
def signed_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@classmethod
def signed_integer_8(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@classmethod
def signed_integer_16(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@classmethod
def signed_integer_32(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@classmethod
def uuid(cls, value: core.UUID) -> DataElement:
return cls(cls.UUID, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@classmethod
def text_string(cls, value: bytes) -> DataElement:
return cls(cls.TEXT_STRING, value)
@staticmethod
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@classmethod
def boolean(cls, value: bool) -> DataElement:
return cls(cls.BOOLEAN, value)
@staticmethod
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@classmethod
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.SEQUENCE, value)
@staticmethod
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@classmethod
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.ALTERNATIVE, value)
@staticmethod
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@classmethod
def url(cls, value: str) -> DataElement:
return cls(cls.URL, value)
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
@classmethod
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2:
return struct.unpack('>H', data)[0]
@classmethod
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 4:
return struct.unpack('>I', data)[0]
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 8:
return struct.unpack('>Q', data)[0]
@classmethod
def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
depth = getattr(_parse_state, "depth", 0)
if depth >= _MAX_DATA_ELEMENT_NESTING:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth "
f"({_MAX_DATA_ELEMENT_NESTING})"
)
_parse_state.depth = depth + 1
try:
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
finally:
_parse_state.depth = depth
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
def __bytes__(self) -> bytes:
# Return early if we have a cache
if self._bytes:
return self._bytes
if self.type == DataElement.NIL:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
match self.type:
case DataElement.NIL:
data = b''
case DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
elif self.value_size == 4:
data = struct.pack('>I', self.value)
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
elif self.value_size == 2:
data = struct.pack('>h', self.value)
elif self.value_size == 4:
data = struct.pack('>i', self.value)
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
match self.value_size:
case 1:
data = struct.pack('B', self.value)
case 2:
data = struct.pack('>H', self.value)
case 4:
data = struct.pack('>I', self.value)
case 8:
data = struct.pack('>Q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.SIGNED_INTEGER:
match self.value_size:
case 1:
data = struct.pack('b', self.value)
case 2:
data = struct.pack('>h', self.value)
case 4:
data = struct.pack('>i', self.value)
case 8:
data = struct.pack('>q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.UUID:
data = bytes(self.value)[::-1]
case DataElement.URL:
data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
size = len(data)
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
match self.type:
case DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
case (
DataElement.UNSIGNED_INTEGER
| DataElement.SIGNED_INTEGER
| DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
case (
DataElement.TEXT_STRING
| DataElement.SEQUENCE
| DataElement.ALTERNATIVE
| DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty=False, indentation=0):
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
prefix = ' ' * indentation
type_name = self.type.name
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
match self.type:
case DataElement.NIL:
value_string = ''
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
case _:
if isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})'
def __str__(self):
def __str__(self) -> str:
return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute: