mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
Add support for ATT permissions on server-side
This commit is contained in:
@@ -25,12 +25,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import struct
|
import struct
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
from typing import Dict, Type
|
from typing import Dict, Type, TYPE_CHECKING
|
||||||
|
|
||||||
from bumble.core import UUID, name_or_number
|
from bumble.core import UUID, name_or_number
|
||||||
from bumble.hci import HCI_Object, key_with_value
|
from bumble.hci import HCI_Object, key_with_value
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from bumble.device import Connection
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
@@ -749,7 +751,25 @@ class Attribute(EventEmitter):
|
|||||||
def decode_value(self, value_bytes):
|
def decode_value(self, value_bytes):
|
||||||
return value_bytes
|
return value_bytes
|
||||||
|
|
||||||
def read_value(self, connection):
|
def read_value(self, connection: Connection):
|
||||||
|
if (
|
||||||
|
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
||||||
|
) and not connection.encryption:
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
||||||
|
) and not connection.authenticated:
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
|
||||||
|
# TODO: handle authorization better
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
|
||||||
if read := getattr(self.value, 'read', None):
|
if read := getattr(self.value, 'read', None):
|
||||||
try:
|
try:
|
||||||
value = read(connection) # pylint: disable=not-callable
|
value = read(connection) # pylint: disable=not-callable
|
||||||
@@ -762,7 +782,25 @@ class Attribute(EventEmitter):
|
|||||||
|
|
||||||
return self.encode_value(value)
|
return self.encode_value(value)
|
||||||
|
|
||||||
def write_value(self, connection, value_bytes):
|
def write_value(self, connection: Connection, value_bytes):
|
||||||
|
if (
|
||||||
|
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||||
|
) and not connection.encryption:
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
|
||||||
|
) and not connection.authenticated:
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
|
||||||
|
# TODO: handle authorization better
|
||||||
|
raise ATT_Error(
|
||||||
|
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||||
|
)
|
||||||
|
|
||||||
value = self.decode_value(value_bytes)
|
value = self.decode_value(value_bytes)
|
||||||
|
|
||||||
if write := getattr(self.value, 'write', None):
|
if write := getattr(self.value, 'write', None):
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ from .att import (
|
|||||||
from .gatt import (
|
from .gatt import (
|
||||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
|
||||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||||
GATT_REQUEST_TIMEOUT,
|
GATT_REQUEST_TIMEOUT,
|
||||||
@@ -543,8 +542,6 @@ class Server(EventEmitter):
|
|||||||
if attribute.handle >= request.starting_handle
|
if attribute.handle >= request.starting_handle
|
||||||
and attribute.handle <= request.ending_handle
|
and attribute.handle <= request.ending_handle
|
||||||
):
|
):
|
||||||
# TODO: check permissions
|
|
||||||
|
|
||||||
this_uuid_size = len(attribute.type.to_pdu_bytes())
|
this_uuid_size = len(attribute.type.to_pdu_bytes())
|
||||||
|
|
||||||
if attributes:
|
if attributes:
|
||||||
@@ -638,6 +635,13 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
pdu_space_available = connection.att_mtu - 2
|
pdu_space_available = connection.att_mtu - 2
|
||||||
|
|
||||||
|
response = ATT_Error_Response(
|
||||||
|
request_opcode_in_error=request.op_code,
|
||||||
|
attribute_handle_in_error=request.starting_handle,
|
||||||
|
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute
|
attribute
|
||||||
@@ -647,10 +651,21 @@ class Server(EventEmitter):
|
|||||||
and attribute.handle <= request.ending_handle
|
and attribute.handle <= request.ending_handle
|
||||||
and pdu_space_available
|
and pdu_space_available
|
||||||
):
|
):
|
||||||
# TODO: check permissions
|
|
||||||
|
try:
|
||||||
|
attribute_value = attribute.read_value(connection)
|
||||||
|
except ATT_Error as error:
|
||||||
|
# If the first attribute is unreadable, return an error
|
||||||
|
# Otherwise return attributes up to this point
|
||||||
|
if not attributes:
|
||||||
|
response = ATT_Error_Response(
|
||||||
|
request_opcode_in_error=request.op_code,
|
||||||
|
attribute_handle_in_error=attribute.handle,
|
||||||
|
error_code=error.error_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
# Check the attribute value size
|
# Check the attribute value size
|
||||||
attribute_value = attribute.read_value(connection)
|
|
||||||
max_attribute_size = min(connection.att_mtu - 4, 253)
|
max_attribute_size = min(connection.att_mtu - 4, 253)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
@@ -676,11 +691,7 @@ class Server(EventEmitter):
|
|||||||
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
|
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = ATT_Error_Response(
|
logging.warning(f"not found {request}")
|
||||||
request_opcode_in_error=request.op_code,
|
|
||||||
attribute_handle_in_error=request.starting_handle,
|
|
||||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.send_response(connection, response)
|
self.send_response(connection, response)
|
||||||
|
|
||||||
@@ -690,10 +701,17 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
try:
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
value_size = min(connection.att_mtu - 1, len(value))
|
except ATT_Error as error:
|
||||||
response = ATT_Read_Response(attribute_value=value[:value_size])
|
response = ATT_Error_Response(
|
||||||
|
request_opcode_in_error=request.op_code,
|
||||||
|
attribute_handle_in_error=request.attribute_handle,
|
||||||
|
error_code=error.error_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
value_size = min(connection.att_mtu - 1, len(value))
|
||||||
|
response = ATT_Read_Response(attribute_value=value[:value_size])
|
||||||
else:
|
else:
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error=request.op_code,
|
request_opcode_in_error=request.op_code,
|
||||||
@@ -708,29 +726,36 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
try:
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
if request.value_offset > len(value):
|
except ATT_Error as error:
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error=request.op_code,
|
request_opcode_in_error=request.op_code,
|
||||||
attribute_handle_in_error=request.attribute_handle,
|
attribute_handle_in_error=request.attribute_handle,
|
||||||
error_code=ATT_INVALID_OFFSET_ERROR,
|
error_code=error.error_code,
|
||||||
)
|
|
||||||
elif len(value) <= connection.att_mtu - 1:
|
|
||||||
response = ATT_Error_Response(
|
|
||||||
request_opcode_in_error=request.op_code,
|
|
||||||
attribute_handle_in_error=request.attribute_handle,
|
|
||||||
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
part_size = min(
|
if request.value_offset > len(value):
|
||||||
connection.att_mtu - 1, len(value) - request.value_offset
|
response = ATT_Error_Response(
|
||||||
)
|
request_opcode_in_error=request.op_code,
|
||||||
response = ATT_Read_Blob_Response(
|
attribute_handle_in_error=request.attribute_handle,
|
||||||
part_attribute_value=value[
|
error_code=ATT_INVALID_OFFSET_ERROR,
|
||||||
request.value_offset : request.value_offset + part_size
|
)
|
||||||
]
|
elif len(value) <= connection.att_mtu - 1:
|
||||||
)
|
response = ATT_Error_Response(
|
||||||
|
request_opcode_in_error=request.op_code,
|
||||||
|
attribute_handle_in_error=request.attribute_handle,
|
||||||
|
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
part_size = min(
|
||||||
|
connection.att_mtu - 1, len(value) - request.value_offset
|
||||||
|
)
|
||||||
|
response = ATT_Read_Blob_Response(
|
||||||
|
part_attribute_value=value[
|
||||||
|
request.value_offset : request.value_offset + part_size
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error=request.op_code,
|
request_opcode_in_error=request.op_code,
|
||||||
@@ -746,7 +771,6 @@ class Server(EventEmitter):
|
|||||||
if request.attribute_group_type not in (
|
if request.attribute_group_type not in (
|
||||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
|
||||||
):
|
):
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error=request.op_code,
|
request_opcode_in_error=request.op_code,
|
||||||
@@ -766,8 +790,10 @@ class Server(EventEmitter):
|
|||||||
and attribute.handle <= request.ending_handle
|
and attribute.handle <= request.ending_handle
|
||||||
and pdu_space_available
|
and pdu_space_available
|
||||||
):
|
):
|
||||||
# Check the attribute value size
|
# No need to catch permission errors here, since these attributes
|
||||||
|
# must all be world-readable
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
|
# Check the attribute value size
|
||||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
|
|||||||
Reference in New Issue
Block a user