SDP: Move parser functions to parser class

This commit is contained in:
Josh Wu
2026-04-27 10:54:32 +08:00
parent 05accbf805
commit 0813da2278

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import struct import struct
import threading
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar
@@ -49,7 +48,6 @@ logger = logging.getLogger(__name__)
# prevent a malicious peer from crashing the process via a deeply nested PDU. # prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses. # 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32 _MAX_DATA_ELEMENT_NESTING = 32
_parse_state = threading.local()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -156,32 +154,6 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type type: Type
value: Any value: Any
value_size: int | None = None value_size: int | None = None
@@ -198,289 +170,354 @@ class DataElement:
'integer types must have a value size specified' 'integer types must have a value size specified'
) )
@staticmethod @classmethod
def nil() -> DataElement: def nil(cls) -> DataElement:
return DataElement(DataElement.NIL, None) return cls(cls.NIL, None)
@staticmethod @classmethod
def unsigned_integer(value: int, value_size: int) -> DataElement: def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size) return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod @classmethod
def unsigned_integer_8(value: int) -> DataElement: def unsigned_integer_8(cls, value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1) return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod @classmethod
def unsigned_integer_16(value: int) -> DataElement: def unsigned_integer_16(cls, value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2) return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod @classmethod
def unsigned_integer_32(value: int) -> DataElement: def unsigned_integer_32(cls, value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4) return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod @classmethod
def signed_integer(value: int, value_size: int) -> DataElement: def signed_integer(cls, value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size) return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod @classmethod
def signed_integer_8(value: int) -> DataElement: def signed_integer_8(cls, value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1) return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod @classmethod
def signed_integer_16(value: int) -> DataElement: def signed_integer_16(cls, value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2) return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod @classmethod
def signed_integer_32(value: int) -> DataElement: def signed_integer_32(cls, value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4) return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod @classmethod
def uuid(value: core.UUID) -> DataElement: def uuid(cls, value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value) return cls(cls.UUID, value)
@staticmethod @classmethod
def text_string(value: bytes) -> DataElement: def text_string(cls, value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value) return cls(cls.TEXT_STRING, value)
@staticmethod @classmethod
def boolean(value: bool) -> DataElement: def boolean(cls, value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value) return cls(cls.BOOLEAN, value)
@staticmethod @classmethod
def sequence(value: Iterable[DataElement]) -> DataElement: def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value) return cls(cls.SEQUENCE, value)
@staticmethod @classmethod
def alternative(value: Iterable[DataElement]) -> DataElement: def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value) return cls(cls.ALTERNATIVE, value)
@staticmethod @classmethod
def url(value: str) -> DataElement: def url(cls, value: str) -> DataElement:
return DataElement(DataElement.URL, value) return cls(cls.URL, value)
@staticmethod @classmethod
def unsigned_integer_from_bytes(data): def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
if len(data) == 1: match length:
return data[0] case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2: @classmethod
return struct.unpack('>H', data)[0] def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 4: @classmethod
return struct.unpack('>I', data)[0] def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 8: @classmethod
return struct.unpack('>Q', data)[0] def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
raise InvalidPacketError(f'invalid integer length {len(data)}') def __bytes__(self) -> bytes:
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
depth = getattr(_parse_state, "depth", 0)
if depth >= _MAX_DATA_ELEMENT_NESTING:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth "
f"({_MAX_DATA_ELEMENT_NESTING})"
)
_parse_state.depth = depth + 1
try:
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
finally:
_parse_state.depth = depth
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
# Return early if we have a cache # Return early if we have a cache
if self._bytes: if self._bytes:
return self._bytes return self._bytes
if self.type == DataElement.NIL: match self.type:
data = b'' case DataElement.NIL:
elif self.type == DataElement.UNSIGNED_INTEGER: data = b''
if self.value < 0: case DataElement.UNSIGNED_INTEGER:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative') if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1: match self.value_size:
data = struct.pack('B', self.value) case 1:
elif self.value_size == 2: data = struct.pack('B', self.value)
data = struct.pack('>H', self.value) case 2:
elif self.value_size == 4: data = struct.pack('>H', self.value)
data = struct.pack('>I', self.value) case 4:
elif self.value_size == 8: data = struct.pack('>I', self.value)
data = struct.pack('>Q', self.value) case 8:
else: data = struct.pack('>Q', self.value)
raise InvalidArgumentError('invalid value_size') case invalid_length:
elif self.type == DataElement.SIGNED_INTEGER: raise InvalidArgumentError(
if self.value_size == 1: f'invalid value_size of {invalid_length}'
data = struct.pack('b', self.value) )
elif self.value_size == 2: case DataElement.SIGNED_INTEGER:
data = struct.pack('>h', self.value) match self.value_size:
elif self.value_size == 4: case 1:
data = struct.pack('>i', self.value) data = struct.pack('b', self.value)
elif self.value_size == 8: case 2:
data = struct.pack('>q', self.value) data = struct.pack('>h', self.value)
else: case 4:
raise InvalidArgumentError('invalid value_size') data = struct.pack('>i', self.value)
elif self.type == DataElement.UUID: case 8:
data = bytes(reversed(bytes(self.value))) data = struct.pack('>q', self.value)
elif self.type == DataElement.URL: case invalid_length:
data = self.value.encode('utf8') raise InvalidArgumentError(
elif self.type == DataElement.BOOLEAN: f'invalid value_size of {invalid_length}'
data = bytes([1 if self.value else 0]) )
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): case DataElement.UUID:
data = b''.join([bytes(element) for element in self.value]) data = bytes(self.value)[::-1]
else: case DataElement.URL:
data = self.value data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
size = len(data) size = len(data)
size_bytes = b'' size_bytes = b''
if self.type == DataElement.NIL: match self.type:
if size != 0: case DataElement.NIL:
raise InvalidArgumentError('NIL must be empty') if size != 0:
size_index = 0 raise InvalidArgumentError('NIL must be empty')
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
size_index = 0 size_index = 0
elif size == 2: case (
size_index = 1 DataElement.UNSIGNED_INTEGER
elif size == 4: | DataElement.SIGNED_INTEGER
size_index = 2 | DataElement.UUID
elif size == 8: ):
size_index = 3 if size <= 1:
elif size == 16: size_index = 0
size_index = 4 elif size == 2:
else: size_index = 1
raise InvalidArgumentError('invalid data size') elif size == 4:
elif self.type in ( size_index = 2
DataElement.TEXT_STRING, elif size == 8:
DataElement.SEQUENCE, size_index = 3
DataElement.ALTERNATIVE, elif size == 16:
DataElement.URL, size_index = 4
): else:
if size <= 0xFF: raise InvalidArgumentError('invalid data size')
size_index = 5 case (
size_bytes = bytes([size]) DataElement.TEXT_STRING
elif size <= 0xFFFF: | DataElement.SEQUENCE
size_index = 6 | DataElement.ALTERNATIVE
size_bytes = struct.pack('>H', size) | DataElement.URL
elif size <= 0xFFFFFFFF: ):
size_index = 7 if size <= 0xFF:
size_bytes = struct.pack('>I', size) size_index = 5
else: size_bytes = bytes([size])
raise InvalidArgumentError('invalid data size') elif size <= 0xFFFF:
elif self.type == DataElement.BOOLEAN: size_index = 6
if size != 1: size_bytes = struct.pack('>H', size)
raise InvalidArgumentError('boolean must be 1 byte') elif size <= 0xFFFFFFFF:
size_index = 0 size_index = 7
else: size_bytes = struct.pack('>I', size)
raise RuntimeError("internal error - self.type not supported") else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes return self._bytes
def to_string(self, pretty=False, indentation=0): def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
prefix = ' ' * indentation prefix = ' ' * indentation
type_name = self.type.name type_name = self.type.name
if self.type == DataElement.NIL: match self.type:
value_string = '' case DataElement.NIL:
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): value_string = ''
container_separator = '\n' if pretty else '' case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
element_separator = '\n' if pretty else ',' container_separator = '\n' if pretty else ''
elements = [ element_separator = '\n' if pretty else ','
element.to_string(pretty, indentation + 1 if pretty else 0) elements = [
for element in self.value element.to_string(pretty, indentation + 1 if pretty else 0)
] for element in self.value
value_string = ( ]
f'[{container_separator}' value_string = (
f'{element_separator.join(elements)}' f'[{container_separator}'
f'{container_separator}{prefix}]' f'{element_separator.join(elements)}'
) f'{container_separator}{prefix}]'
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): )
value_string = f'{self.value}#{self.value_size}' case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
elif isinstance(self.value, DataElement): value_string = f'{self.value}#{self.value_size}'
value_string = self.value.to_string(pretty, indentation) case _:
else: if isinstance(self.value, DataElement):
value_string = str(self.value) value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})' return f'{prefix}{type_name}({value_string})'
def __str__(self): def __str__(self) -> str:
return self.to_string() return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclass @dataclass
class ServiceAttribute: class ServiceAttribute: