diff --git a/bumble/sdp.py b/bumble/sdp.py index afabd9a..0562ebd 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -21,11 +21,12 @@ import asyncio import logging import struct from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, NewType +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NewType from typing_extensions import Self -from bumble import core, l2cap +from bumble import core, l2cap, utils from bumble.colors import color from bumble.core import ( InvalidArgumentError, @@ -141,30 +142,31 @@ SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF) # ----------------------------------------------------------------------------- +@dataclass class DataElement: - NIL = 0 - UNSIGNED_INTEGER = 1 - SIGNED_INTEGER = 2 - UUID = 3 - TEXT_STRING = 4 - BOOLEAN = 5 - SEQUENCE = 6 - ALTERNATIVE = 7 - URL = 8 - TYPE_NAMES = { - NIL: 'NIL', - UNSIGNED_INTEGER: 'UNSIGNED_INTEGER', - SIGNED_INTEGER: 'SIGNED_INTEGER', - UUID: 'UUID', - TEXT_STRING: 'TEXT_STRING', - BOOLEAN: 'BOOLEAN', - SEQUENCE: 'SEQUENCE', - ALTERNATIVE: 'ALTERNATIVE', - URL: 'URL', - } + class Type(utils.OpenIntEnum): + NIL = 0 + UNSIGNED_INTEGER = 1 + SIGNED_INTEGER = 2 + UUID = 3 + TEXT_STRING = 4 + BOOLEAN = 5 + SEQUENCE = 6 + ALTERNATIVE = 7 + URL = 8 - type_constructors = { + NIL = Type.NIL + UNSIGNED_INTEGER = Type.UNSIGNED_INTEGER + SIGNED_INTEGER = Type.SIGNED_INTEGER + UUID = Type.UUID + TEXT_STRING = Type.TEXT_STRING + BOOLEAN = Type.BOOLEAN + SEQUENCE = Type.SEQUENCE + ALTERNATIVE = Type.ALTERNATIVE + URL = Type.URL + + TYPE_CONSTRUCTORS = { NIL: lambda x: DataElement(DataElement.NIL, None), UNSIGNED_INTEGER: lambda x, y: DataElement( DataElement.UNSIGNED_INTEGER, @@ -190,14 +192,18 @@ class DataElement: URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')), } - def __init__(self, element_type, value, value_size=None): - self.type = element_type - self.value = value - self.value_size = value_size + type: Type + value: Any + value_size: int | None = None + + def __post_init__(self) -> None: # Used as a cache when parsing from bytes so we can emit a byte-for-byte replica - self.bytes = None - if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): - if value_size is None: + self._bytes: bytes | None = None + if self.type in ( + DataElement.UNSIGNED_INTEGER, + DataElement.SIGNED_INTEGER, + ): + if self.value_size is None: raise InvalidArgumentError( 'integer types must have a value size specified' ) @@ -337,7 +343,7 @@ class DataElement: value_offset = 4 value_data = data[1 + value_offset : 1 + value_offset + value_size] - constructor = DataElement.type_constructors.get(element_type) + constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type) if constructor: if element_type in ( DataElement.UNSIGNED_INTEGER, @@ -348,15 +354,15 @@ class DataElement: result = constructor(value_data) else: result = DataElement(element_type, value_data) - result.bytes = data[ + result._bytes = data[ : 1 + value_offset + value_size ] # Keep a copy so we can re-serialize to an exact replica return result def __bytes__(self): # Return early if we have a cache - if self.bytes: - return self.bytes + if self._bytes: + return self._bytes if self.type == DataElement.NIL: data = b'' @@ -443,12 +449,12 @@ class DataElement: else: raise RuntimeError("internal error - self.type not supported") - self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data - return self.bytes + self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data + return self._bytes def to_string(self, pretty=False, indentation=0): prefix = ' ' * indentation - type_name = name_or_number(self.TYPE_NAMES, self.type) + type_name = self.type.name if self.type == DataElement.NIL: value_string = '' elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): @@ -476,10 +482,10 @@ class DataElement: # ----------------------------------------------------------------------------- +@dataclass class ServiceAttribute: - def __init__(self, attribute_id: int, value: DataElement) -> None: - self.id = attribute_id - self.value = value + id: int + value: DataElement @staticmethod def list_from_data_elements(