From 17563e423a6674bb91dea821eb2bebe2f4904eb9 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 2 Jul 2025 13:46:40 +0800 Subject: [PATCH] Migrate ATT PDU to dataclasses --- bumble/att.py | 489 +++++++++++++++++++++--------------------- bumble/gatt_client.py | 138 ++++++------ bumble/gatt_server.py | 223 +++++++++---------- tests/gatt_test.py | 4 +- 4 files changed, 423 insertions(+), 431 deletions(-) diff --git a/bumble/att.py b/bumble/att.py index 88846e2..d62a645 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -24,6 +24,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations +import dataclasses import enum import functools import inspect @@ -34,13 +35,16 @@ from typing import ( Generic, TypeVar, Union, + ClassVar, + Optional, TYPE_CHECKING, ) +from bumble import hci from bumble import utils -from bumble.core import UUID, name_or_number, InvalidOperationError, ProtocolError -from bumble.hci import HCI_Object, key_with_value +from bumble.core import UUID, InvalidOperationError, ProtocolError +from bumble.hci import HCI_Object from bumble.colors import color # ----------------------------------------------------------------------------- @@ -60,96 +64,66 @@ _T = TypeVar('_T') ATT_CID = 0x04 ATT_PSM = 0x001F -ATT_ERROR_RESPONSE = 0x01 -ATT_EXCHANGE_MTU_REQUEST = 0x02 -ATT_EXCHANGE_MTU_RESPONSE = 0x03 -ATT_FIND_INFORMATION_REQUEST = 0x04 -ATT_FIND_INFORMATION_RESPONSE = 0x05 -ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06 -ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07 -ATT_READ_BY_TYPE_REQUEST = 0x08 -ATT_READ_BY_TYPE_RESPONSE = 0x09 -ATT_READ_REQUEST = 0x0A -ATT_READ_RESPONSE = 0x0B -ATT_READ_BLOB_REQUEST = 0x0C -ATT_READ_BLOB_RESPONSE = 0x0D -ATT_READ_MULTIPLE_REQUEST = 0x0E -ATT_READ_MULTIPLE_RESPONSE = 0x0F -ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10 -ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11 -ATT_WRITE_REQUEST = 0x12 -ATT_WRITE_RESPONSE = 0x13 -ATT_WRITE_COMMAND = 0x52 -ATT_SIGNED_WRITE_COMMAND = 0xD2 -ATT_PREPARE_WRITE_REQUEST = 0x16 -ATT_PREPARE_WRITE_RESPONSE = 0x17 -ATT_EXECUTE_WRITE_REQUEST = 0x18 -ATT_EXECUTE_WRITE_RESPONSE = 0x19 -ATT_HANDLE_VALUE_NOTIFICATION = 0x1B -ATT_HANDLE_VALUE_INDICATION = 0x1D -ATT_HANDLE_VALUE_CONFIRMATION = 0x1E - -ATT_PDU_NAMES = { - ATT_ERROR_RESPONSE: 'ATT_ERROR_RESPONSE', - ATT_EXCHANGE_MTU_REQUEST: 'ATT_EXCHANGE_MTU_REQUEST', - ATT_EXCHANGE_MTU_RESPONSE: 'ATT_EXCHANGE_MTU_RESPONSE', - ATT_FIND_INFORMATION_REQUEST: 'ATT_FIND_INFORMATION_REQUEST', - ATT_FIND_INFORMATION_RESPONSE: 'ATT_FIND_INFORMATION_RESPONSE', - ATT_FIND_BY_TYPE_VALUE_REQUEST: 'ATT_FIND_BY_TYPE_VALUE_REQUEST', - ATT_FIND_BY_TYPE_VALUE_RESPONSE: 'ATT_FIND_BY_TYPE_VALUE_RESPONSE', - ATT_READ_BY_TYPE_REQUEST: 'ATT_READ_BY_TYPE_REQUEST', - ATT_READ_BY_TYPE_RESPONSE: 'ATT_READ_BY_TYPE_RESPONSE', - ATT_READ_REQUEST: 'ATT_READ_REQUEST', - ATT_READ_RESPONSE: 'ATT_READ_RESPONSE', - ATT_READ_BLOB_REQUEST: 'ATT_READ_BLOB_REQUEST', - ATT_READ_BLOB_RESPONSE: 'ATT_READ_BLOB_RESPONSE', - ATT_READ_MULTIPLE_REQUEST: 'ATT_READ_MULTIPLE_REQUEST', - ATT_READ_MULTIPLE_RESPONSE: 'ATT_READ_MULTIPLE_RESPONSE', - ATT_READ_BY_GROUP_TYPE_REQUEST: 'ATT_READ_BY_GROUP_TYPE_REQUEST', - ATT_READ_BY_GROUP_TYPE_RESPONSE: 'ATT_READ_BY_GROUP_TYPE_RESPONSE', - ATT_WRITE_REQUEST: 'ATT_WRITE_REQUEST', - ATT_WRITE_RESPONSE: 'ATT_WRITE_RESPONSE', - ATT_WRITE_COMMAND: 'ATT_WRITE_COMMAND', - ATT_SIGNED_WRITE_COMMAND: 'ATT_SIGNED_WRITE_COMMAND', - ATT_PREPARE_WRITE_REQUEST: 'ATT_PREPARE_WRITE_REQUEST', - ATT_PREPARE_WRITE_RESPONSE: 'ATT_PREPARE_WRITE_RESPONSE', - ATT_EXECUTE_WRITE_REQUEST: 'ATT_EXECUTE_WRITE_REQUEST', - ATT_EXECUTE_WRITE_RESPONSE: 'ATT_EXECUTE_WRITE_RESPONSE', - ATT_HANDLE_VALUE_NOTIFICATION: 'ATT_HANDLE_VALUE_NOTIFICATION', - ATT_HANDLE_VALUE_INDICATION: 'ATT_HANDLE_VALUE_INDICATION', - ATT_HANDLE_VALUE_CONFIRMATION: 'ATT_HANDLE_VALUE_CONFIRMATION' -} +class Opcode(hci.SpecableEnum): + ATT_ERROR_RESPONSE = 0x01 + ATT_EXCHANGE_MTU_REQUEST = 0x02 + ATT_EXCHANGE_MTU_RESPONSE = 0x03 + ATT_FIND_INFORMATION_REQUEST = 0x04 + ATT_FIND_INFORMATION_RESPONSE = 0x05 + ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06 + ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07 + ATT_READ_BY_TYPE_REQUEST = 0x08 + ATT_READ_BY_TYPE_RESPONSE = 0x09 + ATT_READ_REQUEST = 0x0A + ATT_READ_RESPONSE = 0x0B + ATT_READ_BLOB_REQUEST = 0x0C + ATT_READ_BLOB_RESPONSE = 0x0D + ATT_READ_MULTIPLE_REQUEST = 0x0E + ATT_READ_MULTIPLE_RESPONSE = 0x0F + ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10 + ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11 + ATT_WRITE_REQUEST = 0x12 + ATT_WRITE_RESPONSE = 0x13 + ATT_WRITE_COMMAND = 0x52 + ATT_SIGNED_WRITE_COMMAND = 0xD2 + ATT_PREPARE_WRITE_REQUEST = 0x16 + ATT_PREPARE_WRITE_RESPONSE = 0x17 + ATT_EXECUTE_WRITE_REQUEST = 0x18 + ATT_EXECUTE_WRITE_RESPONSE = 0x19 + ATT_HANDLE_VALUE_NOTIFICATION = 0x1B + ATT_HANDLE_VALUE_INDICATION = 0x1D + ATT_HANDLE_VALUE_CONFIRMATION = 0x1E ATT_REQUESTS = [ - ATT_EXCHANGE_MTU_REQUEST, - ATT_FIND_INFORMATION_REQUEST, - ATT_FIND_BY_TYPE_VALUE_REQUEST, - ATT_READ_BY_TYPE_REQUEST, - ATT_READ_REQUEST, - ATT_READ_BLOB_REQUEST, - ATT_READ_MULTIPLE_REQUEST, - ATT_READ_BY_GROUP_TYPE_REQUEST, - ATT_WRITE_REQUEST, - ATT_PREPARE_WRITE_REQUEST, - ATT_EXECUTE_WRITE_REQUEST + Opcode.ATT_EXCHANGE_MTU_REQUEST, + Opcode.ATT_FIND_INFORMATION_REQUEST, + Opcode.ATT_FIND_BY_TYPE_VALUE_REQUEST, + Opcode.ATT_READ_BY_TYPE_REQUEST, + Opcode.ATT_READ_REQUEST, + Opcode.ATT_READ_BLOB_REQUEST, + Opcode.ATT_READ_MULTIPLE_REQUEST, + Opcode.ATT_READ_BY_GROUP_TYPE_REQUEST, + Opcode.ATT_WRITE_REQUEST, + Opcode.ATT_PREPARE_WRITE_REQUEST, + Opcode.ATT_EXECUTE_WRITE_REQUEST ] ATT_RESPONSES = [ - ATT_ERROR_RESPONSE, - ATT_EXCHANGE_MTU_RESPONSE, - ATT_FIND_INFORMATION_RESPONSE, - ATT_FIND_BY_TYPE_VALUE_RESPONSE, - ATT_READ_BY_TYPE_RESPONSE, - ATT_READ_RESPONSE, - ATT_READ_BLOB_RESPONSE, - ATT_READ_MULTIPLE_RESPONSE, - ATT_READ_BY_GROUP_TYPE_RESPONSE, - ATT_WRITE_RESPONSE, - ATT_PREPARE_WRITE_RESPONSE, - ATT_EXECUTE_WRITE_RESPONSE + Opcode.ATT_ERROR_RESPONSE, + Opcode.ATT_EXCHANGE_MTU_RESPONSE, + Opcode.ATT_FIND_INFORMATION_RESPONSE, + Opcode.ATT_FIND_BY_TYPE_VALUE_RESPONSE, + Opcode.ATT_READ_BY_TYPE_RESPONSE, + Opcode.ATT_READ_RESPONSE, + Opcode.ATT_READ_BLOB_RESPONSE, + Opcode.ATT_READ_MULTIPLE_RESPONSE, + Opcode.ATT_READ_BY_GROUP_TYPE_RESPONSE, + Opcode.ATT_WRITE_RESPONSE, + Opcode.ATT_PREPARE_WRITE_RESPONSE, + Opcode.ATT_EXECUTE_WRITE_RESPONSE ] -class ErrorCode(utils.OpenIntEnum): +class ErrorCode(hci.SpecableEnum): ''' See @@ -204,10 +178,6 @@ ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES ATT_DEFAULT_MTU = 23 HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'} -# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda -UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) -# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda -UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 # fmt: on # pylint: enable=line-too-long @@ -227,7 +197,7 @@ class ATT_Error(ProtocolError): super().__init__( error_code, error_namespace='att', - error_name=ATT_PDU.error_name(error_code), + error_name=ErrorCode(error_code).name, ) self.att_handle = att_handle self.message = message @@ -242,61 +212,45 @@ class ATT_Error(ProtocolError): # ----------------------------------------------------------------------------- # Attribute Protocol # ----------------------------------------------------------------------------- +@dataclasses.dataclass class ATT_PDU: ''' See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU ''' - pdu_classes: dict[int, type[ATT_PDU]] = {} - op_code = 0 - name: str - - @staticmethod - def from_bytes(pdu): - op_code = pdu[0] - - cls = ATT_PDU.pdu_classes.get(op_code) - if cls is None: - instance = ATT_PDU(pdu) - instance.name = ATT_PDU.pdu_name(op_code) - instance.op_code = op_code - return instance - self = cls.__new__(cls) - ATT_PDU.__init__(self, pdu) - if hasattr(self, 'fields'): - self.init_from_bytes(pdu, 1) - return self - - @staticmethod - def pdu_name(op_code): - return name_or_number(ATT_PDU_NAMES, op_code, 2) + pdu_classes: ClassVar[dict[int, type[ATT_PDU]]] = {} + fields: ClassVar[hci.Fields] = () + op_code: int = dataclasses.field(init=False) + name: str = dataclasses.field(init=False) + _payload: Optional[bytes] = dataclasses.field(default=None, init=False) @classmethod - def error_name(cls, error_code: int) -> str: - return ErrorCode(error_code).name + def from_bytes(cls, pdu: bytes) -> ATT_PDU: + op_code = pdu[0] - @staticmethod - def subclass(fields): - def inner(cls): - cls.name = cls.__name__.upper() - cls.op_code = key_with_value(ATT_PDU_NAMES, cls.name) - if cls.op_code is None: - raise KeyError(f'PDU name {cls.name} not found in ATT_PDU_NAMES') - cls.fields = fields + subclass = ATT_PDU.pdu_classes.get(op_code) + if subclass is None: + instance = ATT_PDU() + instance.op_code = op_code + instance.payload = pdu[1:] + instance.name = Opcode(op_code).name + return instance + instance = subclass(**HCI_Object.dict_from_bytes(pdu, 1, subclass.fields)) + instance.payload = pdu[1:] + return instance - # Register a factory for this class - ATT_PDU.pdu_classes[cls.op_code] = cls + _PDU = TypeVar("_PDU", bound="ATT_PDU") - return cls + @classmethod + def subclass(cls, subclass: type[_PDU]) -> type[_PDU]: + subclass.name = subclass.__name__.upper() + subclass.op_code = Opcode[subclass.name] + subclass.fields = HCI_Object.fields_from_dataclass(subclass) - return inner + # Register a factory for this class + ATT_PDU.pdu_classes[subclass.op_code] = subclass - def __init__(self, pdu=None, **kwargs): - if hasattr(self, 'fields') and kwargs: - HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - pdu = bytes([self.op_code]) + HCI_Object.dict_to_bytes(kwargs, self.fields) - self.pdu = pdu + return subclass def init_from_bytes(self, pdu, offset): return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) @@ -309,67 +263,91 @@ class ATT_PDU: def has_authentication_signature(self): return ((self.op_code >> 7) & 1) == 1 - def __bytes__(self): - return self.pdu + @property + def payload(self) -> bytes: + if self._payload is None: + self._payload = HCI_Object.dict_to_bytes(self.__dict__, self.fields) + return self._payload + + @payload.setter + def payload(self, value: bytes): + self._payload = value + + def __bytes__(self) -> bytes: + return bytes([self.op_code]) + self.payload def __str__(self): result = color(self.name, 'yellow') if fields := getattr(self, 'fields', None): result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ') else: - if len(self.pdu) > 1: - result += f': {self.pdu.hex()}' + if self.payload: + result += f': {self.payload.hex()}' return result # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), - ('attribute_handle_in_error', HANDLE_FIELD_SPEC), - ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Error_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response ''' + request_opcode_in_error: int = dataclasses.field(metadata=Opcode.type_metadata(1)) + attribute_handle_in_error: int = dataclasses.field( + metadata=hci.metadata(HANDLE_FIELD_SPEC) + ) + error_code: int = dataclasses.field(metadata=ErrorCode.type_metadata(1)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('client_rx_mtu', 2)]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Exchange_MTU_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request ''' + client_rx_mtu: int = dataclasses.field(metadata=hci.metadata(2)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('server_rx_mtu', 2)]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Exchange_MTU_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response ''' + server_rx_mtu: int = dataclasses.field(metadata=hci.metadata(2)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Find_Information_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request ''' + starting_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + ending_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('format', 1), ('information_data', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Find_Information_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response ''' - def parse_information_data(self): + format: int = dataclasses.field(metadata=hci.metadata(1)) + information_data: bytes = dataclasses.field(metadata=hci.metadata("*")) + information: list[tuple[int, bytes]] = dataclasses.field(init=False) + + def __post_init__(self) -> None: self.information = [] offset = 0 uuid_size = 2 if self.format == 1 else 16 @@ -379,14 +357,6 @@ class ATT_Find_Information_Response(ATT_PDU): self.information.append((handle, uuid)) offset += 2 + uuid_size - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.parse_information_data() - - def init_from_bytes(self, pdu, offset): - super().init_from_bytes(pdu, offset) - self.parse_information_data() - def __str__(self): result = color(self.name, 'yellow') result += ':\n' + HCI_Object.format_fields( @@ -408,28 +378,31 @@ class ATT_Find_Information_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('starting_handle', HANDLE_FIELD_SPEC), - ('ending_handle', HANDLE_FIELD_SPEC), - ('attribute_type', UUID_2_FIELD_SPEC), - ('attribute_value', '*'), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Find_By_Type_Value_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request ''' + starting_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + ending_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_type: UUID = dataclasses.field(metadata=hci.metadata(UUID.parse_uuid_2)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('handles_information_list', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Find_By_Type_Value_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response ''' - def parse_handles_information_list(self): + handles_information_list: bytes = dataclasses.field(metadata=hci.metadata("*")) + handles_information: list[tuple[int, int]] = dataclasses.field(init=False) + + def __post_init__(self) -> None: self.handles_information = [] offset = 0 while offset + 4 <= len(self.handles_information_list): @@ -439,14 +412,6 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU): self.handles_information.append((found_attribute_handle, group_end_handle)) offset += 4 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.parse_handles_information_list() - - def init_from_bytes(self, pdu, offset): - super().init_from_bytes(pdu, offset) - self.parse_handles_information_list() - def __str__(self): result = color(self.name, 'yellow') result += ':\n' + HCI_Object.format_fields( @@ -470,27 +435,31 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('starting_handle', HANDLE_FIELD_SPEC), - ('ending_handle', HANDLE_FIELD_SPEC), - ('attribute_type', UUID_2_16_FIELD_SPEC), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_By_Type_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request ''' + starting_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + ending_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_type: UUID = dataclasses.field(metadata=hci.metadata(UUID.parse_uuid)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_By_Type_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response ''' - def parse_attribute_data_list(self): + length: int = dataclasses.field(metadata=hci.metadata(1)) + attribute_data_list: bytes = dataclasses.field(metadata=hci.metadata("*")) + attributes: list[tuple[int, bytes]] = dataclasses.field(init=False) + + def __post_init__(self) -> None: self.attributes = [] offset = 0 while self.length != 0 and offset + self.length <= len( @@ -505,14 +474,6 @@ class ATT_Read_By_Type_Response(ATT_PDU): self.attributes.append((attribute_handle, attribute_value)) offset += self.length - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.parse_attribute_data_list() - - def init_from_bytes(self, pdu, offset): - super().init_from_bytes(pdu, offset) - self.parse_attribute_data_list() - def __str__(self): result = color(self.name, 'yellow') result += ':\n' + HCI_Object.format_fields( @@ -534,75 +495,100 @@ class ATT_Read_By_Type_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response ''' + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Blob_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + value_offset: int = dataclasses.field(metadata=hci.metadata(2)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('part_attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Blob_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response ''' + part_attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('set_of_handles', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Multiple_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request ''' + set_of_handles: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('set_of_values', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_Multiple_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response ''' + set_of_values: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('starting_handle', HANDLE_FIELD_SPEC), - ('ending_handle', HANDLE_FIELD_SPEC), - ('attribute_group_type', UUID_2_16_FIELD_SPEC), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_By_Group_Type_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request ''' + starting_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + ending_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_group_type: UUID = dataclasses.field( + metadata=hci.metadata(UUID.parse_uuid) + ) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Read_By_Group_Type_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response ''' - def parse_attribute_data_list(self): + length: int = dataclasses.field(metadata=hci.metadata(1)) + attribute_data_list: bytes = dataclasses.field(metadata=hci.metadata("*")) + attributes: list[tuple[int, int, bytes]] = dataclasses.field(init=False) + + def __post_init__(self) -> None: self.attributes = [] offset = 0 while self.length != 0 and offset + self.length <= len( @@ -619,14 +605,6 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU): ) offset += self.length - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.parse_attribute_data_list() - - def init_from_bytes(self, pdu, offset): - super().init_from_bytes(pdu, offset) - self.parse_attribute_data_list() - def __str__(self): result = color(self.name, 'yellow') result += ':\n' + HCI_Object.format_fields( @@ -651,15 +629,20 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Write_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Write_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.5.2 Write Response @@ -667,65 +650,70 @@ class ATT_Write_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Write_Command(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('attribute_handle', HANDLE_FIELD_SPEC), - ('attribute_value', '*'), - # ('authentication_signature', 'TODO') - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Signed_Write_Command(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # TODO: authentication_signature + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('attribute_handle', HANDLE_FIELD_SPEC), - ('value_offset', 2), - ('part_attribute_value', '*'), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Prepare_Write_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + value_offset: int = dataclasses.field(metadata=hci.metadata(2)) + part_attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass( - [ - ('attribute_handle', HANDLE_FIELD_SPEC), - ('value_offset', 2), - ('part_attribute_value', '*'), - ] -) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Prepare_Write_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + value_offset: int = dataclasses.field(metadata=hci.metadata(2)) + part_attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([("flags", 1)]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Execute_Write_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.6.3 Execute Write Request ''' + flags: int = dataclasses.field(metadata=hci.metadata(1)) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Execute_Write_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.6.4 Execute Write Response @@ -733,23 +721,32 @@ class ATT_Execute_Write_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Handle_Value_Notification(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Handle_Value_Indication(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication ''' + attribute_handle: int = dataclasses.field(metadata=hci.metadata(HANDLE_FIELD_SPEC)) + attribute_value: bytes = dataclasses.field(metadata=hci.metadata("*")) + # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([]) +@ATT_PDU.subclass +@dataclasses.dataclass class ATT_Handle_Value_Confirmation(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.7.3 Handle Value Confirmation diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 0a50edb..65e106f 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -42,27 +42,7 @@ from typing import ( from bumble.colors import color from bumble.hci import HCI_Constant -from bumble.att import ( - ATT_ATTRIBUTE_NOT_FOUND_ERROR, - ATT_ATTRIBUTE_NOT_LONG_ERROR, - ATT_CID, - ATT_DEFAULT_MTU, - ATT_ERROR_RESPONSE, - ATT_INVALID_OFFSET_ERROR, - ATT_PDU, - ATT_RESPONSES, - ATT_Exchange_MTU_Request, - ATT_Find_By_Type_Value_Request, - ATT_Find_Information_Request, - ATT_Handle_Value_Confirmation, - ATT_Read_Blob_Request, - ATT_Read_By_Group_Type_Request, - ATT_Read_By_Type_Request, - ATT_Read_Request, - ATT_Write_Command, - ATT_Write_Request, - ATT_Error, -) +from bumble import att from bumble import utils from bumble import core from bumble.core import UUID, InvalidStateError @@ -291,8 +271,8 @@ class Client: indication_subscribers: dict[ int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]] ] - pending_response: Optional[asyncio.futures.Future[ATT_PDU]] - pending_request: Optional[ATT_PDU] + pending_response: Optional[asyncio.futures.Future[att.ATT_PDU]] + pending_request: Optional[att.ATT_PDU] def __init__(self, connection: Connection) -> None: self.connection = connection @@ -308,15 +288,15 @@ class Client: connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection) def send_gatt_pdu(self, pdu: bytes) -> None: - self.connection.send_l2cap_pdu(ATT_CID, pdu) + self.connection.send_l2cap_pdu(att.ATT_CID, pdu) - async def send_command(self, command: ATT_PDU) -> None: + async def send_command(self, command: att.ATT_PDU) -> None: logger.debug( f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' ) self.send_gatt_pdu(bytes(command)) - async def send_request(self, request: ATT_PDU): + async def send_request(self, request: att.ATT_PDU): logger.debug( f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' ) @@ -345,7 +325,9 @@ class Client: return response - def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None: + def send_confirmation( + self, confirmation: att.ATT_Handle_Value_Confirmation + ) -> None: logger.debug( f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'{confirmation}' @@ -354,8 +336,8 @@ class Client: async def request_mtu(self, mtu: int) -> int: # Check the range - if mtu < ATT_DEFAULT_MTU: - raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}') + if mtu < att.ATT_DEFAULT_MTU: + raise core.InvalidArgumentError(f'MTU must be >= {att.ATT_DEFAULT_MTU}') if mtu > 0xFFFF: raise core.InvalidArgumentError('MTU must be <= 0xFFFF') @@ -365,9 +347,11 @@ class Client: # Send the request self.mtu_exchange_done = True - response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu)) - if response.op_code == ATT_ERROR_RESPONSE: - raise ATT_Error(error_code=response.error_code, message=response) + response = await self.send_request( + att.ATT_Exchange_MTU_Request(client_rx_mtu=mtu) + ) + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + raise att.ATT_Error(error_code=response.error_code, message=response) # Compute the final MTU self.connection.att_mtu = min(mtu, response.server_rx_mtu) @@ -432,7 +416,7 @@ class Client: services = [] while starting_handle < 0xFFFF: response = await self.send_request( - ATT_Read_By_Group_Type_Request( + att.ATT_Read_By_Group_Type_Request( starting_handle=starting_handle, ending_handle=0xFFFF, attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, @@ -443,14 +427,14 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering services: ' f'{HCI_Constant.error_name(response.error_code)}' ) - raise ATT_Error( + raise att.ATT_Error( error_code=response.error_code, message='Unexpected error while discovering services', ) @@ -509,7 +493,7 @@ class Client: services = [] while starting_handle < 0xFFFF: response = await self.send_request( - ATT_Find_By_Type_Value_Request( + att.ATT_Find_By_Type_Value_Request( starting_handle=starting_handle, ending_handle=0xFFFF, attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, @@ -521,8 +505,8 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering services: ' @@ -578,7 +562,7 @@ class Client: included_services: list[ServiceProxy] = [] while starting_handle <= ending_handle: response = await self.send_request( - ATT_Read_By_Type_Request( + att.ATT_Read_By_Type_Request( starting_handle=starting_handle, ending_handle=ending_handle, attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE, @@ -589,14 +573,14 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering included services: ' f'{HCI_Constant.error_name(response.error_code)}' ) - raise ATT_Error( + raise att.ATT_Error( error_code=response.error_code, message='Unexpected error while discovering included services', ) @@ -652,7 +636,7 @@ class Client: characteristics: list[CharacteristicProxy[bytes]] = [] while starting_handle <= ending_handle: response = await self.send_request( - ATT_Read_By_Type_Request( + att.ATT_Read_By_Type_Request( starting_handle=starting_handle, ending_handle=ending_handle, attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, @@ -663,14 +647,14 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering characteristics: ' f'{HCI_Constant.error_name(response.error_code)}' ) - raise ATT_Error( + raise att.ATT_Error( error_code=response.error_code, message='Unexpected error while discovering characteristics', ) @@ -736,7 +720,7 @@ class Client: descriptors: list[DescriptorProxy] = [] while starting_handle <= ending_handle: response = await self.send_request( - ATT_Find_Information_Request( + att.ATT_Find_Information_Request( starting_handle=starting_handle, ending_handle=ending_handle ) ) @@ -745,8 +729,8 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering descriptors: ' @@ -791,7 +775,7 @@ class Client: attributes = [] while True: response = await self.send_request( - ATT_Find_Information_Request( + att.ATT_Find_Information_Request( starting_handle=starting_handle, ending_handle=ending_handle ) ) @@ -799,8 +783,8 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while discovering attributes: ' @@ -954,12 +938,12 @@ class Client: # Send a request to read attribute_handle = attribute if isinstance(attribute, int) else attribute.handle response = await self.send_request( - ATT_Read_Request(attribute_handle=attribute_handle) + att.ATT_Read_Request(attribute_handle=attribute_handle) ) if response is None: raise TimeoutError('read timeout') - if response.op_code == ATT_ERROR_RESPONSE: - raise ATT_Error(error_code=response.error_code, message=response) + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + raise att.ATT_Error(error_code=response.error_code, message=response) # If the value is the max size for the MTU, try to read more unless the caller # specifically asked not to do that @@ -969,19 +953,21 @@ class Client: offset = len(attribute_value) while True: response = await self.send_request( - ATT_Read_Blob_Request( + att.ATT_Read_Blob_Request( attribute_handle=attribute_handle, value_offset=offset ) ) if response is None: raise TimeoutError('read timeout') - if response.op_code == ATT_ERROR_RESPONSE: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.error_code in ( - ATT_ATTRIBUTE_NOT_LONG_ERROR, - ATT_INVALID_OFFSET_ERROR, + att.ATT_ATTRIBUTE_NOT_LONG_ERROR, + att.ATT_INVALID_OFFSET_ERROR, ): break - raise ATT_Error(error_code=response.error_code, message=response) + raise att.ATT_Error( + error_code=response.error_code, message=response + ) part = response.part_attribute_value attribute_value += part @@ -1012,7 +998,7 @@ class Client: characteristics_values = [] while starting_handle <= ending_handle: response = await self.send_request( - ATT_Read_By_Type_Request( + att.ATT_Read_By_Type_Request( starting_handle=starting_handle, ending_handle=ending_handle, attribute_type=uuid, @@ -1023,8 +1009,8 @@ class Client: return [] # Check if we reached the end of the iteration - if response.op_code == ATT_ERROR_RESPONSE: - if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end logger.warning( '!!! unexpected error while reading characteristics: ' @@ -1069,15 +1055,15 @@ class Client: attribute_handle = attribute if isinstance(attribute, int) else attribute.handle if with_response: response = await self.send_request( - ATT_Write_Request( + att.ATT_Write_Request( attribute_handle=attribute_handle, attribute_value=value ) ) - if response.op_code == ATT_ERROR_RESPONSE: - raise ATT_Error(error_code=response.error_code, message=response) + if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: + raise att.ATT_Error(error_code=response.error_code, message=response) else: await self.send_command( - ATT_Write_Command( + att.ATT_Write_Command( attribute_handle=attribute_handle, attribute_value=value ) ) @@ -1086,11 +1072,11 @@ class Client: if self.pending_response and not self.pending_response.done(): self.pending_response.cancel() - def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: + def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' ) - if att_pdu.op_code in ATT_RESPONSES: + if att_pdu.op_code in att.ATT_RESPONSES: if self.pending_request is None: # Not expected! logger.warning('!!! unexpected response, there is no pending request') @@ -1098,7 +1084,7 @@ class Client: # The response should match the pending request unless it is # an error response - if att_pdu.op_code != ATT_ERROR_RESPONSE: + if att_pdu.op_code != att.Opcode.ATT_ERROR_RESPONSE: expected_response_name = self.pending_request.name.replace( '_REQUEST', '_RESPONSE' ) @@ -1126,7 +1112,9 @@ class Client: + str(att_pdu) ) - def on_att_handle_value_notification(self, notification): + def on_att_handle_value_notification( + self, notification: att.ATT_Handle_Value_Notification + ): # Call all subscribers subscribers = self.notification_subscribers.get( notification.attribute_handle, set() @@ -1141,7 +1129,9 @@ class Client: else: subscriber.emit(subscriber.EVENT_UPDATE, notification.attribute_value) - def on_att_handle_value_indication(self, indication): + def on_att_handle_value_indication( + self, indication: att.ATT_Handle_Value_Indication + ): # Call all subscribers subscribers = self.indication_subscribers.get( indication.attribute_handle, set() @@ -1157,7 +1147,7 @@ class Client: subscriber.emit(subscriber.EVENT_UPDATE, indication.attribute_value) # Confirm that we received the indication - self.send_confirmation(ATT_Handle_Value_Confirmation()) + self.send_confirmation(att.ATT_Handle_Value_Confirmation()) def cache_value(self, attribute_handle: int, value: bytes) -> None: self.cached_values[attribute_handle] = ( diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 329cd55..35bf8c3 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -13,7 +13,7 @@ # limitations under the License. # ----------------------------------------------------------------------------- -# GATT - Generic Attribute Profile +# GATT - Generic att.Attribute Profile # Server # # See Bluetooth spec @ Vol 3, Part G @@ -35,35 +35,9 @@ from typing import ( TYPE_CHECKING, ) +from bumble import att from bumble.colors import color from bumble.core import UUID -from bumble.att import ( - ATT_ATTRIBUTE_NOT_FOUND_ERROR, - ATT_ATTRIBUTE_NOT_LONG_ERROR, - ATT_CID, - ATT_DEFAULT_MTU, - ATT_INVALID_ATTRIBUTE_LENGTH_ERROR, - ATT_INVALID_HANDLE_ERROR, - ATT_INVALID_OFFSET_ERROR, - ATT_REQUEST_NOT_SUPPORTED_ERROR, - ATT_REQUESTS, - ATT_PDU, - ATT_UNLIKELY_ERROR_ERROR, - ATT_UNSUPPORTED_GROUP_TYPE_ERROR, - ATT_Error, - ATT_Error_Response, - ATT_Exchange_MTU_Response, - ATT_Find_By_Type_Value_Response, - ATT_Find_Information_Response, - ATT_Handle_Value_Indication, - ATT_Handle_Value_Notification, - ATT_Read_Blob_Response, - ATT_Read_By_Group_Type_Response, - ATT_Read_By_Type_Response, - ATT_Read_Response, - ATT_Write_Response, - Attribute, -) from bumble.gatt import ( GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, @@ -99,9 +73,9 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517 # GATT Server # ----------------------------------------------------------------------------- class Server(utils.EventEmitter): - attributes: list[Attribute] + attributes: list[att.Attribute] services: list[Service] - attributes_by_handle: dict[int, Attribute] + attributes_by_handle: dict[int, att.Attribute] subscribers: dict[int, dict[int, bytes]] indication_semaphores: defaultdict[int, asyncio.Semaphore] pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]] @@ -112,7 +86,7 @@ class Server(utils.EventEmitter): super().__init__() self.device = device self.services = [] - self.attributes = [] # Attributes, ordered by increasing handle values + self.attributes = [] # att.Attributes, ordered by increasing handle values self.attributes_by_handle = {} # Map for fast attribute access by handle self.max_mtu = ( GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate @@ -127,12 +101,12 @@ class Server(utils.EventEmitter): return "\n".join(map(str, self.attributes)) def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None: - self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) + self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu) def next_handle(self) -> int: return 1 + len(self.attributes) - def get_advertising_service_data(self) -> dict[Attribute, bytes]: + def get_advertising_service_data(self) -> dict[att.Attribute, bytes]: return { attribute: data for attribute in self.attributes @@ -140,7 +114,7 @@ class Server(utils.EventEmitter): and (data := attribute.get_advertising_data()) } - def get_attribute(self, handle: int) -> Optional[Attribute]: + def get_attribute(self, handle: int) -> Optional[att.Attribute]: attribute = self.attributes_by_handle.get(handle) if attribute: return attribute @@ -231,7 +205,7 @@ class Server(utils.EventEmitter): None, ) - def add_attribute(self, attribute: Attribute) -> None: + def add_attribute(self, attribute: att.Attribute) -> None: # Assign a handle to this attribute attribute.handle = self.next_handle() attribute.end_group_handle = ( @@ -286,7 +260,7 @@ class Server(utils.EventEmitter): # pylint: disable=line-too-long Descriptor( GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, - Attribute.READABLE | Attribute.WRITEABLE, + att.Attribute.READABLE | att.Attribute.WRITEABLE, CharacteristicValue( read=lambda connection, characteristic=characteristic: self.read_cccd( connection, characteristic @@ -355,7 +329,7 @@ class Server(utils.EventEmitter): indicate_enabled, ) - def send_response(self, connection: Connection, response: ATT_PDU) -> None: + def send_response(self, connection: Connection, response: att.ATT_PDU) -> None: logger.debug( f'GATT Response from server: [0x{connection.handle:04X}] {response}' ) @@ -364,7 +338,7 @@ class Server(utils.EventEmitter): async def notify_subscriber( self, connection: Connection, - attribute: Attribute, + attribute: att.Attribute, value: Optional[bytes] = None, force: bool = False, ) -> None: @@ -396,7 +370,7 @@ class Server(utils.EventEmitter): value = value[: connection.att_mtu - 3] # Notify - notification = ATT_Handle_Value_Notification( + notification = att.ATT_Handle_Value_Notification( attribute_handle=attribute.handle, attribute_value=value ) logger.debug( @@ -407,7 +381,7 @@ class Server(utils.EventEmitter): async def indicate_subscriber( self, connection: Connection, - attribute: Attribute, + attribute: att.Attribute, value: Optional[bytes] = None, force: bool = False, ) -> None: @@ -439,7 +413,7 @@ class Server(utils.EventEmitter): value = value[: connection.att_mtu - 3] # Indicate - indication = ATT_Handle_Value_Indication( + indication = att.ATT_Handle_Value_Indication( attribute_handle=attribute.handle, attribute_value=value ) logger.debug( @@ -467,7 +441,7 @@ class Server(utils.EventEmitter): async def _notify_or_indicate_subscribers( self, indicate: bool, - attribute: Attribute, + attribute: att.Attribute, value: Optional[bytes] = None, force: bool = False, ) -> None: @@ -494,7 +468,7 @@ class Server(utils.EventEmitter): async def notify_subscribers( self, - attribute: Attribute, + attribute: att.Attribute, value: Optional[bytes] = None, force: bool = False, ): @@ -504,7 +478,7 @@ class Server(utils.EventEmitter): async def indicate_subscribers( self, - attribute: Attribute, + attribute: att.Attribute, value: Optional[bytes] = None, force: bool = False, ): @@ -518,16 +492,16 @@ class Server(utils.EventEmitter): if connection.handle in self.pending_confirmations: del self.pending_confirmations[connection.handle] - def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None: + def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') handler_name = f'on_{att_pdu.name.lower()}' handler = getattr(self, handler_name, None) if handler is not None: try: handler(connection, att_pdu) - except ATT_Error as error: + except att.ATT_Error as error: logger.debug(f'normal exception returned by handler: {error}') - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=att_pdu.op_code, attribute_handle_in_error=error.att_handle, error_code=error.error_code, @@ -535,16 +509,16 @@ class Server(utils.EventEmitter): self.send_response(connection, response) except Exception as error: logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=att_pdu.op_code, attribute_handle_in_error=0x0000, - error_code=ATT_UNLIKELY_ERROR_ERROR, + error_code=att.ATT_UNLIKELY_ERROR_ERROR, ) self.send_response(connection, response) raise error else: # No specific handler registered - if att_pdu.op_code in ATT_REQUESTS: + if att_pdu.op_code in att.ATT_REQUESTS: # Invoke the generic handler self.on_att_request(connection, att_pdu) else: @@ -560,7 +534,7 @@ class Server(utils.EventEmitter): ####################################################### # ATT handlers ####################################################### - def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None: + def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None: ''' Handler for requests without a more specific handler ''' @@ -570,23 +544,25 @@ class Server(utils.EventEmitter): ) + str(pdu) ) - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=pdu.op_code, attribute_handle_in_error=0x0000, - error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR, + error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR, ) self.send_response(connection, response) - def on_att_exchange_mtu_request(self, connection, request): + def on_att_exchange_mtu_request( + self, connection: Connection, request: att.ATT_Exchange_MTU_Request + ): ''' See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request ''' self.send_response( - connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) + connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) ) # Compute the final MTU - if request.client_rx_mtu >= ATT_DEFAULT_MTU: + if request.client_rx_mtu >= att.ATT_DEFAULT_MTU: mtu = min(self.max_mtu, request.client_rx_mtu) # Notify the device @@ -594,11 +570,14 @@ class Server(utils.EventEmitter): else: logger.warning('invalid client_rx_mtu received, MTU not changed') - def on_att_find_information_request(self, connection, request): + def on_att_find_information_request( + self, connection: Connection, request: att.ATT_Find_Information_Request + ): ''' See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request ''' + response: att.ATT_PDU # Check the request parameters if ( request.starting_handle == 0 @@ -606,17 +585,17 @@ class Server(utils.EventEmitter): ): self.send_response( connection, - ATT_Error_Response( + att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.starting_handle, - error_code=ATT_INVALID_HANDLE_ERROR, + error_code=att.ATT_INVALID_HANDLE_ERROR, ), ) return # Build list of returned attributes pdu_space_available = connection.att_mtu - 2 - attributes = [] + attributes: list[att.Attribute] = [] uuid_size = 0 for attribute in ( attribute @@ -646,21 +625,23 @@ class Server(utils.EventEmitter): struct.pack(' len(value): - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, - error_code=ATT_INVALID_OFFSET_ERROR, + error_code=att.ATT_INVALID_OFFSET_ERROR, ) elif len(value) <= connection.att_mtu - 1: - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, - error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR, + error_code=att.ATT_ATTRIBUTE_NOT_LONG_ERROR, ) else: part_size = min( connection.att_mtu - 1, len(value) - request.value_offset ) - response = ATT_Read_Blob_Response( + response = att.ATT_Read_Blob_Response( part_attribute_value=value[ request.value_offset : request.value_offset + part_size ] ) else: - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, - error_code=ATT_INVALID_HANDLE_ERROR, + error_code=att.ATT_INVALID_HANDLE_ERROR, ) self.send_response(connection, response) @utils.AsyncRunner.run_in_task() - async def on_att_read_by_group_type_request(self, connection, request): + async def on_att_read_by_group_type_request( + self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request + ): ''' See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request ''' + response: att.ATT_PDU if request.attribute_group_type not in ( GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, ): - response = ATT_Error_Response( + response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.starting_handle, - error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR, + error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ) self.send_response(connection, response) return pdu_space_available = connection.att_mtu - 2 - attributes = [] + attributes: list[tuple[int, int, bytes]] = [] for attribute in ( attribute for attribute in self.attributes @@ -904,21 +897,23 @@ class Server(utils.EventEmitter): struct.pack(' GATT_MAX_ATTRIBUTE_VALUE_SIZE: self.send_response( connection, - ATT_Error_Response( + att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, - error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR, + error_code=att.ATT_INVALID_ATTRIBUTE_LENGTH_ERROR, ), ) return + response: att.ATT_PDU try: # Accept the value await attribute.write_value(connection, request.attribute_value) - except ATT_Error as error: - response = ATT_Error_Response( + except att.ATT_Error as error: + response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, error_code=error.error_code, ) else: # Done - response = ATT_Write_Response() + response = att.ATT_Write_Response() self.send_response(connection, response) @utils.AsyncRunner.run_in_task() - async def on_att_write_command(self, connection, request): + async def on_att_write_command( + self, connection: Connection, request: att.ATT_Write_Command + ): ''' See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command ''' @@ -987,15 +985,22 @@ class Server(utils.EventEmitter): except Exception as error: logger.exception(f'!!! ignoring exception: {error}') - def on_att_handle_value_confirmation(self, connection, _confirmation): + def on_att_handle_value_confirmation( + self, + connection: Connection, + confirmation: att.ATT_Handle_Value_Confirmation, + ): ''' See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation ''' - if self.pending_confirmations[connection.handle] is None: + del confirmation # Unused. + if ( + pending_confirmation := self.pending_confirmations[connection.handle] + ) is None: # Not expected! logger.warning( '!!! unexpected confirmation, there is no pending indication' ) return - self.pending_confirmations[connection.handle].set_result(None) + pending_confirmation.set_result(None) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 4af609a..55c6c6d 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -58,7 +58,7 @@ from bumble.transport.common import AsyncPipeSink from bumble.core import UUID from bumble.att import ( Attribute, - ATT_EXCHANGE_MTU_REQUEST, + Opcode, ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_PDU, ATT_Error, @@ -103,7 +103,7 @@ def test_UUID(): # ----------------------------------------------------------------------------- def test_ATT_Error_Response(): pdu = ATT_Error_Response( - request_opcode_in_error=ATT_EXCHANGE_MTU_REQUEST, + request_opcode_in_error=Opcode.ATT_EXCHANGE_MTU_REQUEST, attribute_handle_in_error=0x0000, error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR, )