mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Add EATT Support
This commit is contained in:
@@ -298,6 +298,7 @@ class Speaker:
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
eatt_enabled=True,
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
|
||||
@@ -34,10 +34,13 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
ClassVar,
|
||||
Generic,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from bumble import hci, utils
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from bumble import hci, l2cap, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, InvalidOperationError, ProtocolError
|
||||
from bumble.hci import HCI_Object
|
||||
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
|
||||
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
|
||||
|
||||
|
||||
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
|
||||
return isinstance(bearer, EnhancedBearer)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -58,6 +69,7 @@ _T = TypeVar('_T')
|
||||
|
||||
ATT_CID = 0x04
|
||||
ATT_PSM = 0x001F
|
||||
EATT_PSM = 0x0027
|
||||
|
||||
class Opcode(hci.SpecableEnum):
|
||||
ATT_ERROR_RESPONSE = 0x01
|
||||
@@ -780,6 +792,43 @@ class AttributeValue(Generic[_T]):
|
||||
return self._write(connection, value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeValueV2(Generic[_T]):
|
||||
'''
|
||||
Attribute value compatible with enhanced bearers.
|
||||
|
||||
The only difference between AttributeValue and AttributeValueV2 is that the actual
|
||||
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
|
||||
will be passed into read and write callbacks in V2, while in V1 it is always
|
||||
the base ACL connection.
|
||||
|
||||
This is only required when attributes must distinguish bearers, otherwise normal
|
||||
`AttributeValue` objects are also applicable in enhanced bearers.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
|
||||
write: (
|
||||
Callable[[Bearer, _T], Awaitable[None]]
|
||||
| Callable[[Bearer, _T], None]
|
||||
| None
|
||||
) = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
|
||||
if self._read is None:
|
||||
raise InvalidOperationError('AttributeValue has no read function')
|
||||
return self._read(bearer)
|
||||
|
||||
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
|
||||
if self._write is None:
|
||||
raise InvalidOperationError('AttributeValue has no write function')
|
||||
return self._write(bearer, value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
class Permissions(enum.IntFlag):
|
||||
@@ -855,7 +904,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
def decode_value(self, value: bytes) -> _T:
|
||||
return value # type: ignore
|
||||
|
||||
async def read_value(self, connection: Connection) -> bytes:
|
||||
async def read_value(self, bearer: Bearer) -> bytes:
|
||||
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
|
||||
if (
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
@@ -890,6 +940,17 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
elif isinstance(self.value, AttributeValueV2):
|
||||
try:
|
||||
read_value = self.value.read(bearer)
|
||||
if inspect.isawaitable(read_value):
|
||||
value = await read_value
|
||||
else:
|
||||
value = read_value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
else:
|
||||
value = self.value
|
||||
|
||||
@@ -897,7 +958,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
|
||||
return b'' if value is None else self.encode_value(value)
|
||||
|
||||
async def write_value(self, connection: Connection, value: bytes) -> None:
|
||||
async def write_value(self, bearer: Bearer, value: bytes) -> None:
|
||||
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
|
||||
if (
|
||||
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
@@ -931,6 +993,15 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
elif isinstance(self.value, AttributeValueV2):
|
||||
try:
|
||||
result = self.value.write(bearer, decoded_value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
) from error
|
||||
else:
|
||||
self.value = decoded_value
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from typing import (
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import (
|
||||
att,
|
||||
core,
|
||||
data_types,
|
||||
gatt,
|
||||
@@ -53,7 +54,6 @@ from bumble import (
|
||||
smp,
|
||||
utils,
|
||||
)
|
||||
from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
@@ -1743,7 +1743,6 @@ class Connection(utils.CompositeEventEmitter):
|
||||
EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure"
|
||||
EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update"
|
||||
EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure"
|
||||
EVENT_CONNECTION_ATT_MTU_UPDATE = "connection_att_mtu_update"
|
||||
EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change"
|
||||
EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = (
|
||||
"channel_sounding_capabilities_failure"
|
||||
@@ -1846,7 +1845,7 @@ class Connection(utils.CompositeEventEmitter):
|
||||
self.encryption_key_size = 0
|
||||
self.authenticated = False
|
||||
self.sc = False
|
||||
self.att_mtu = ATT_DEFAULT_MTU
|
||||
self.att_mtu = att.ATT_DEFAULT_MTU
|
||||
self.data_length = DEVICE_DEFAULT_DATA_LENGTH
|
||||
self.gatt_client = gatt_client.Client(self) # Per-connection client
|
||||
self.gatt_server = (
|
||||
@@ -1996,6 +1995,15 @@ class Connection(utils.CompositeEventEmitter):
|
||||
self.peer_le_features = await self.device.get_remote_le_features(self)
|
||||
return self.peer_le_features
|
||||
|
||||
def on_att_mtu_update(self, mtu: int):
|
||||
logger.debug(
|
||||
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
|
||||
f'{self.peer_address} as {self.role_name}, '
|
||||
f'{mtu}'
|
||||
)
|
||||
self.att_mtu = mtu
|
||||
self.emit(self.EVENT_CONNECTION_ATT_MTU_UPDATE)
|
||||
|
||||
@property
|
||||
def data_packet_queue(self) -> DataPacketQueue | None:
|
||||
return self.device.host.get_data_packet_queue(self.handle)
|
||||
@@ -2079,6 +2087,7 @@ class DeviceConfiguration:
|
||||
l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION,
|
||||
l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE,
|
||||
)
|
||||
eatt_enabled: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.gatt_services: list[dict[str, Any]] = []
|
||||
@@ -2497,7 +2506,10 @@ class Device(utils.CompositeEventEmitter):
|
||||
add_gap_service=config.gap_service_enabled,
|
||||
add_gatt_service=config.gatt_service_enabled,
|
||||
)
|
||||
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu)
|
||||
self.l2cap_channel_manager.register_fixed_channel(att.ATT_CID, self.on_gatt_pdu)
|
||||
|
||||
if self.config.eatt_enabled:
|
||||
self.gatt_server.register_eatt()
|
||||
|
||||
# Forward some events
|
||||
utils.setup_event_forwarding(
|
||||
@@ -5140,7 +5152,11 @@ class Device(utils.CompositeEventEmitter):
|
||||
if add_gap_service:
|
||||
self.gatt_server.add_service(GenericAccessService(self.name))
|
||||
if add_gatt_service:
|
||||
self.gatt_service = gatt_service.GenericAttributeProfileService()
|
||||
self.gatt_service = gatt_service.GenericAttributeProfileService(
|
||||
gatt.ServerSupportedFeatures.EATT_SUPPORTED
|
||||
if self.config.eatt_enabled
|
||||
else None
|
||||
)
|
||||
self.gatt_server.add_service(self.gatt_service)
|
||||
|
||||
async def notify_subscriber(
|
||||
@@ -6240,17 +6256,6 @@ class Device(utils.CompositeEventEmitter):
|
||||
)
|
||||
connection.emit(connection.EVENT_LE_SUBRATE_CHANGE)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
|
||||
logger.debug(
|
||||
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address} as {connection.role_name}, '
|
||||
f'{att_mtu}'
|
||||
)
|
||||
connection.att_mtu = att_mtu
|
||||
connection.emit(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_connection_data_length_change(
|
||||
@@ -6437,7 +6442,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
@with_connection_from_handle
|
||||
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
|
||||
# Parse the L2CAP payload into an ATT PDU object
|
||||
att_pdu = ATT_PDU.from_bytes(pdu)
|
||||
att_pdu = att.ATT_PDU.from_bytes(pdu)
|
||||
|
||||
# Conveniently, even-numbered op codes are client->server and
|
||||
# odd-numbered ones are server->client
|
||||
|
||||
@@ -31,7 +31,7 @@ import struct
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
from bumble.att import Attribute, AttributeValue, AttributeValueV2
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, BaseBumbleError
|
||||
|
||||
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
|
||||
value_str = '<dynamic>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from collections.abc import Callable, Iterable
|
||||
@@ -35,9 +36,10 @@ from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from bumble import att, core, utils
|
||||
from bumble import att, core, l2cap, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, InvalidStateError
|
||||
from bumble.gatt import (
|
||||
@@ -54,12 +56,12 @@ from bumble.gatt import (
|
||||
)
|
||||
from bumble.hci import HCI_Constant
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble import device as device_module
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing
|
||||
# -----------------------------------------------------------------------------
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -267,8 +269,8 @@ class Client:
|
||||
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
|
||||
pending_request: att.ATT_PDU | None
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
def __init__(self, bearer: att.Bearer) -> None:
|
||||
self.bearer = bearer
|
||||
self.mtu_exchange_done = False
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request = None
|
||||
@@ -278,21 +280,78 @@ class Client:
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
|
||||
if att.is_enhanced_bearer(bearer):
|
||||
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
|
||||
self._bearer_id = (
|
||||
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
|
||||
)
|
||||
# Fill the mtu.
|
||||
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
|
||||
self.connection = bearer.connection
|
||||
else:
|
||||
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
|
||||
self._bearer_id = f'[0x{bearer.handle:04X}]'
|
||||
self.connection = bearer
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect_eatt(
|
||||
cls,
|
||||
connection: device_module.Connection,
|
||||
spec: l2cap.LeCreditBasedChannelSpec | None = None,
|
||||
) -> Client: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def connect_eatt(
|
||||
cls,
|
||||
connection: device_module.Connection,
|
||||
spec: l2cap.LeCreditBasedChannelSpec | None = None,
|
||||
count: int = 1,
|
||||
) -> list[Client]: ...
|
||||
|
||||
@classmethod
|
||||
async def connect_eatt(
|
||||
cls,
|
||||
connection: device_module.Connection,
|
||||
spec: l2cap.LeCreditBasedChannelSpec | None = None,
|
||||
count: int = 1,
|
||||
) -> list[Client] | Client:
|
||||
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
|
||||
connection,
|
||||
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
|
||||
count,
|
||||
)
|
||||
|
||||
def on_pdu(client: Client, pdu: bytes):
|
||||
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
|
||||
|
||||
clients = [cls(channel) for channel in channels]
|
||||
for channel, client in zip(channels, clients):
|
||||
channel.sink = functools.partial(on_pdu, client)
|
||||
channel.att_mtu = att.ATT_DEFAULT_MTU
|
||||
return clients[0] if count == 1 else clients
|
||||
|
||||
@property
|
||||
def mtu(self) -> int:
|
||||
return self.bearer.att_mtu
|
||||
|
||||
@mtu.setter
|
||||
def mtu(self, value: int) -> None:
|
||||
self.bearer.on_att_mtu_update(value)
|
||||
|
||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
|
||||
if att.is_enhanced_bearer(self.bearer):
|
||||
self.bearer.write(pdu)
|
||||
else:
|
||||
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
|
||||
|
||||
async def send_command(self, command: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
|
||||
self.send_gatt_pdu(bytes(command))
|
||||
|
||||
async def send_request(self, request: att.ATT_PDU):
|
||||
logger.debug(
|
||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||
)
|
||||
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
|
||||
|
||||
# Wait until we can send (only one pending command at a time for the connection)
|
||||
response = None
|
||||
@@ -321,10 +380,7 @@ class Client:
|
||||
def send_confirmation(
|
||||
self, confirmation: att.ATT_Handle_Value_Confirmation
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
|
||||
self.send_gatt_pdu(bytes(confirmation))
|
||||
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
@@ -336,7 +392,7 @@ class Client:
|
||||
|
||||
# We can only send one request per connection
|
||||
if self.mtu_exchange_done:
|
||||
return self.connection.att_mtu
|
||||
return self.mtu
|
||||
|
||||
# Send the request
|
||||
self.mtu_exchange_done = True
|
||||
@@ -347,9 +403,9 @@ class Client:
|
||||
raise att.ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
# Compute the final MTU
|
||||
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
|
||||
self.mtu = min(mtu, response.server_rx_mtu)
|
||||
|
||||
return self.connection.att_mtu
|
||||
return self.mtu
|
||||
|
||||
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
|
||||
return [service for service in self.services if service.uuid == uuid]
|
||||
@@ -942,7 +998,7 @@ class Client:
|
||||
# If the value is the max size for the MTU, try to read more unless the caller
|
||||
# specifically asked not to do that
|
||||
attribute_value = response.attribute_value
|
||||
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
|
||||
if not no_long_read and len(attribute_value) == self.mtu - 1:
|
||||
logger.debug('using READ BLOB to get the rest of the value')
|
||||
offset = len(attribute_value)
|
||||
while True:
|
||||
@@ -966,7 +1022,7 @@ class Client:
|
||||
part = response.part_attribute_value
|
||||
attribute_value += part
|
||||
|
||||
if len(part) < self.connection.att_mtu - 1:
|
||||
if len(part) < self.mtu - 1:
|
||||
break
|
||||
|
||||
offset += len(part)
|
||||
@@ -1062,14 +1118,13 @@ class Client:
|
||||
)
|
||||
)
|
||||
|
||||
def on_disconnection(self, _) -> None:
|
||||
def on_disconnection(self, *args) -> None:
|
||||
del args # unused.
|
||||
if self.pending_response and not self.pending_response.done():
|
||||
self.pending_response.cancel()
|
||||
|
||||
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||
)
|
||||
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
|
||||
if att_pdu.op_code in att.ATT_RESPONSES:
|
||||
if self.pending_request is None:
|
||||
# Not expected!
|
||||
@@ -1099,8 +1154,7 @@ class Client:
|
||||
else:
|
||||
logger.warning(
|
||||
color(
|
||||
'--- Ignoring GATT Response from '
|
||||
f'[0x{self.connection.handle:04X}]: ',
|
||||
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
|
||||
'red',
|
||||
)
|
||||
+ str(att_pdu)
|
||||
|
||||
@@ -32,9 +32,8 @@ from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from bumble import att, utils
|
||||
from bumble import att, core, l2cap, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
@@ -44,14 +43,13 @@ from bumble.gatt import (
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
CharacteristicDeclaration,
|
||||
CharacteristicValue,
|
||||
Descriptor,
|
||||
IncludedServiceDeclaration,
|
||||
Service,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.device import Device
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -65,6 +63,18 @@ logger = logging.getLogger(__name__)
|
||||
GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _bearer_id(bearer: att.Bearer) -> str:
|
||||
if att.is_enhanced_bearer(bearer):
|
||||
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
|
||||
else:
|
||||
return f'[0x{bearer.handle:04X}]'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Server
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -72,9 +82,9 @@ class Server(utils.EventEmitter):
|
||||
attributes: list[att.Attribute]
|
||||
services: list[Service]
|
||||
attributes_by_handle: dict[int, att.Attribute]
|
||||
subscribers: dict[int, dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[int, asyncio.futures.Future | None]
|
||||
subscribers: dict[att.Bearer, dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
|
||||
|
||||
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
|
||||
|
||||
@@ -96,8 +106,29 @@ class Server(utils.EventEmitter):
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
|
||||
def register_eatt(
|
||||
self, spec: l2cap.LeCreditBasedChannelSpec | None = None
|
||||
) -> l2cap.LeCreditBasedChannelServer:
|
||||
def on_channel(channel: l2cap.LeCreditBasedChannel):
|
||||
logger.debug(
|
||||
"New EATT Bearer Conenction=0x%04X CID=0x%04X",
|
||||
channel.connection.handle,
|
||||
channel.source_cid,
|
||||
)
|
||||
channel.att_mtu = att.ATT_DEFAULT_MTU
|
||||
channel.sink = lambda pdu: self.on_gatt_pdu(
|
||||
channel, att.ATT_PDU.from_bytes(pdu)
|
||||
)
|
||||
|
||||
return self.device.create_l2cap_server(
|
||||
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
|
||||
)
|
||||
|
||||
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
|
||||
if att.is_enhanced_bearer(bearer):
|
||||
bearer.write(pdu)
|
||||
else:
|
||||
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
|
||||
|
||||
def next_handle(self) -> int:
|
||||
return 1 + len(self.attributes)
|
||||
@@ -138,7 +169,7 @@ class Server(utils.EventEmitter):
|
||||
None,
|
||||
)
|
||||
|
||||
def get_service_attribute(self, service_uuid: UUID) -> Service | None:
|
||||
def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
@@ -151,7 +182,7 @@ class Server(utils.EventEmitter):
|
||||
)
|
||||
|
||||
def get_characteristic_attributes(
|
||||
self, service_uuid: UUID, characteristic_uuid: UUID
|
||||
self, service_uuid: core.UUID, characteristic_uuid: core.UUID
|
||||
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
|
||||
service_handle = self.get_service_attribute(service_uuid)
|
||||
if not service_handle:
|
||||
@@ -176,7 +207,10 @@ class Server(utils.EventEmitter):
|
||||
)
|
||||
|
||||
def get_descriptor_attribute(
|
||||
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
|
||||
self,
|
||||
service_uuid: core.UUID,
|
||||
characteristic_uuid: core.UUID,
|
||||
descriptor_uuid: core.UUID,
|
||||
) -> Descriptor | None:
|
||||
characteristics = self.get_characteristic_attributes(
|
||||
service_uuid, characteristic_uuid
|
||||
@@ -257,14 +291,7 @@ class Server(utils.EventEmitter):
|
||||
Descriptor(
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
att.Attribute.READABLE | att.Attribute.WRITEABLE,
|
||||
CharacteristicValue(
|
||||
read=lambda connection, characteristic=characteristic: self.read_cccd(
|
||||
connection, characteristic
|
||||
),
|
||||
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
|
||||
connection, characteristic, value
|
||||
),
|
||||
),
|
||||
self.make_descriptor_value(characteristic),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -280,10 +307,21 @@ class Server(utils.EventEmitter):
|
||||
for service in services:
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(
|
||||
self, connection: Connection, characteristic: Characteristic
|
||||
) -> bytes:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
def make_descriptor_value(
|
||||
self, characteristic: Characteristic
|
||||
) -> att.AttributeValueV2:
|
||||
# It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
|
||||
return att.AttributeValueV2(
|
||||
lambda bearer, characteristic=characteristic: self.read_cccd(
|
||||
bearer, characteristic
|
||||
),
|
||||
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
|
||||
bearer, characteristic, value
|
||||
),
|
||||
)
|
||||
|
||||
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
|
||||
subscribers = self.subscribers.get(bearer)
|
||||
cccd = None
|
||||
if subscribers:
|
||||
cccd = subscribers.get(characteristic.handle)
|
||||
@@ -292,12 +330,12 @@ class Server(utils.EventEmitter):
|
||||
|
||||
def write_cccd(
|
||||
self,
|
||||
connection: Connection,
|
||||
bearer: att.Bearer,
|
||||
characteristic: Characteristic,
|
||||
value: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'Subscription update for connection=0x{connection.handle:04X}, '
|
||||
f'Subscription update for connection={_bearer_id(bearer)}, '
|
||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||
)
|
||||
|
||||
@@ -306,41 +344,60 @@ class Server(utils.EventEmitter):
|
||||
logger.warning('CCCD value not 2 bytes long')
|
||||
return
|
||||
|
||||
cccds = self.subscribers.setdefault(connection.handle, {})
|
||||
cccds = self.subscribers.setdefault(bearer, {})
|
||||
cccds[characteristic.handle] = value
|
||||
logger.debug(f'CCCDs: {cccds}')
|
||||
notify_enabled = value[0] & 0x01 != 0
|
||||
indicate_enabled = value[0] & 0x02 != 0
|
||||
characteristic.emit(
|
||||
characteristic.EVENT_SUBSCRIPTION,
|
||||
connection,
|
||||
bearer,
|
||||
notify_enabled,
|
||||
indicate_enabled,
|
||||
)
|
||||
self.emit(
|
||||
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
|
||||
connection,
|
||||
bearer,
|
||||
characteristic,
|
||||
notify_enabled,
|
||||
indicate_enabled,
|
||||
)
|
||||
|
||||
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(response))
|
||||
def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
|
||||
self.send_gatt_pdu(bearer, bytes(response))
|
||||
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
return await self._notify_single_subscriber(bearer, attribute, value, force)
|
||||
else:
|
||||
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
|
||||
bearers = [
|
||||
channel
|
||||
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
|
||||
bearer.handle, {}
|
||||
).values()
|
||||
if channel.psm == att.EATT_PSM
|
||||
] + [bearer]
|
||||
for bearer in bearers:
|
||||
await self._notify_single_subscriber(bearer, attribute, value, force)
|
||||
|
||||
async def _notify_single_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
subscribers = self.subscribers.get(bearer)
|
||||
if not subscribers:
|
||||
logger.debug('not notifying, no subscribers')
|
||||
return
|
||||
@@ -356,34 +413,53 @@ class Server(utils.EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
await attribute.read_value(connection)
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > connection.att_mtu - 3:
|
||||
value = value[: connection.att_mtu - 3]
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
|
||||
# Notify
|
||||
notification = att.ATT_Handle_Value_Notification(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
)
|
||||
logger.debug(
|
||||
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
||||
self.send_gatt_pdu(bearer, bytes(notification))
|
||||
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
return await self._notify_single_subscriber(bearer, attribute, value, force)
|
||||
else:
|
||||
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
|
||||
bearers = [
|
||||
channel
|
||||
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
|
||||
bearer.handle, {}
|
||||
).values()
|
||||
if channel.psm == att.EATT_PSM
|
||||
] + [bearer]
|
||||
for bearer in bearers:
|
||||
await self._indicate_single_bearer(bearer, attribute, value, force)
|
||||
|
||||
async def _indicate_single_bearer(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
subscribers = self.subscribers.get(bearer)
|
||||
if not subscribers:
|
||||
logger.debug('not indicating, no subscribers')
|
||||
return
|
||||
@@ -399,40 +475,38 @@ class Server(utils.EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
await attribute.read_value(connection)
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > connection.att_mtu - 3:
|
||||
value = value[: connection.att_mtu - 3]
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
|
||||
# Indicate
|
||||
indication = att.ATT_Handle_Value_Indication(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
)
|
||||
logger.debug(
|
||||
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
|
||||
)
|
||||
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
||||
|
||||
# Wait until we can send (only one pending indication at a time per connection)
|
||||
async with self.indication_semaphores[connection.handle]:
|
||||
assert self.pending_confirmations[connection.handle] is None
|
||||
async with self.indication_semaphores[bearer]:
|
||||
assert self.pending_confirmations[bearer] is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
pending_confirmation = self.pending_confirmations[connection.handle] = (
|
||||
pending_confirmation = self.pending_confirmations[bearer] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, bytes(indication))
|
||||
self.send_gatt_pdu(bearer, bytes(indication))
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
||||
finally:
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
self.pending_confirmations[bearer] = None
|
||||
|
||||
async def _notify_or_indicate_subscribers(
|
||||
self,
|
||||
@@ -441,24 +515,24 @@ class Server(utils.EventEmitter):
|
||||
value: bytes | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection
|
||||
for connection in [
|
||||
self.device.lookup_connection(connection_handle)
|
||||
for (connection_handle, subscribers) in self.subscribers.items()
|
||||
if force or subscribers.get(attribute.handle)
|
||||
]
|
||||
if connection is not None
|
||||
# Get all the bearers for which there's at least one subscription
|
||||
bearers: list[att.Bearer] = [
|
||||
bearer
|
||||
for bearer, subscribers in self.subscribers.items()
|
||||
if force or subscribers.get(attribute.handle)
|
||||
]
|
||||
|
||||
# Indicate or notify for each connection
|
||||
if connections:
|
||||
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
|
||||
if bearers:
|
||||
coroutine = (
|
||||
self._indicate_single_bearer
|
||||
if indicate
|
||||
else self._notify_single_subscriber
|
||||
)
|
||||
await asyncio.wait(
|
||||
[
|
||||
asyncio.create_task(coroutine(connection, attribute, value, force))
|
||||
for connection in connections
|
||||
asyncio.create_task(coroutine(bearer, attribute, value, force))
|
||||
for bearer in bearers
|
||||
]
|
||||
)
|
||||
|
||||
@@ -480,21 +554,18 @@ class Server(utils.EventEmitter):
|
||||
):
|
||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection: Connection) -> None:
|
||||
if connection.handle in self.subscribers:
|
||||
del self.subscribers[connection.handle]
|
||||
if connection.handle in self.indication_semaphores:
|
||||
del self.indication_semaphores[connection.handle]
|
||||
if connection.handle in self.pending_confirmations:
|
||||
del self.pending_confirmations[connection.handle]
|
||||
def on_disconnection(self, bearer: att.Bearer) -> None:
|
||||
self.subscribers.pop(bearer, None)
|
||||
self.indication_semaphores.pop(bearer, None)
|
||||
self.pending_confirmations.pop(bearer, None)
|
||||
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||
def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler is not None:
|
||||
try:
|
||||
handler(connection, att_pdu)
|
||||
handler(bearer, att_pdu)
|
||||
except att.ATT_Error as error:
|
||||
logger.debug(f'normal exception returned by handler: {error}')
|
||||
response = att.ATT_Error_Response(
|
||||
@@ -502,7 +573,7 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=error.att_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
response = att.ATT_Error_Response(
|
||||
@@ -510,18 +581,18 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=0x0000,
|
||||
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
raise
|
||||
else:
|
||||
# No specific handler registered
|
||||
if att_pdu.op_code in att.ATT_REQUESTS:
|
||||
# Invoke the generic handler
|
||||
self.on_att_request(connection, att_pdu)
|
||||
self.on_att_request(bearer, att_pdu)
|
||||
else:
|
||||
# Just ignore
|
||||
logger.warning(
|
||||
color(
|
||||
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
|
||||
f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
|
||||
'red',
|
||||
)
|
||||
+ str(att_pdu)
|
||||
@@ -530,13 +601,14 @@ class Server(utils.EventEmitter):
|
||||
#######################################################
|
||||
# ATT handlers
|
||||
#######################################################
|
||||
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
|
||||
def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
|
||||
'''
|
||||
Handler for requests without a more specific handler
|
||||
'''
|
||||
logger.warning(
|
||||
color(
|
||||
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
|
||||
f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
|
||||
'red',
|
||||
)
|
||||
+ str(pdu)
|
||||
)
|
||||
@@ -545,29 +617,28 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=0x0000,
|
||||
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
def on_att_exchange_mtu_request(
|
||||
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
||||
'''
|
||||
self.send_response(
|
||||
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
|
||||
bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
|
||||
)
|
||||
|
||||
# Compute the final MTU
|
||||
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
|
||||
mtu = min(self.max_mtu, request.client_rx_mtu)
|
||||
|
||||
# Notify the device
|
||||
self.device.on_connection_att_mtu_update(connection.handle, mtu)
|
||||
bearer.on_att_mtu_update(mtu)
|
||||
else:
|
||||
logger.warning('invalid client_rx_mtu received, MTU not changed')
|
||||
|
||||
def on_att_find_information_request(
|
||||
self, connection: Connection, request: att.ATT_Find_Information_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
|
||||
@@ -580,7 +651,7 @@ class Server(utils.EventEmitter):
|
||||
or request.starting_handle > request.ending_handle
|
||||
):
|
||||
self.send_response(
|
||||
connection,
|
||||
bearer,
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
@@ -590,7 +661,7 @@ class Server(utils.EventEmitter):
|
||||
return
|
||||
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
pdu_space_available = bearer.att_mtu - 2
|
||||
attributes: list[att.Attribute] = []
|
||||
uuid_size = 0
|
||||
for attribute in (
|
||||
@@ -632,18 +703,18 @@ class Server(utils.EventEmitter):
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(
|
||||
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
pdu_space_available = bearer.att_mtu - 2
|
||||
attributes = []
|
||||
response: att.ATT_PDU
|
||||
async for attribute in (
|
||||
@@ -652,7 +723,7 @@ class Server(utils.EventEmitter):
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
and attribute.type == request.attribute_type
|
||||
and (await attribute.read_value(connection)) == request.attribute_value
|
||||
and (await attribute.read_value(bearer)) == request.attribute_value
|
||||
and pdu_space_available >= 4
|
||||
):
|
||||
# TODO: check permissions
|
||||
@@ -688,17 +759,17 @@ class Server(utils.EventEmitter):
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(
|
||||
self, connection: Connection, request: att.ATT_Read_By_Type_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
pdu_space_available = bearer.att_mtu - 2
|
||||
|
||||
response: att.ATT_PDU = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -716,7 +787,7 @@ class Server(utils.EventEmitter):
|
||||
and pdu_space_available
|
||||
):
|
||||
try:
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(bearer)
|
||||
except att.ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
@@ -729,7 +800,7 @@ class Server(utils.EventEmitter):
|
||||
break
|
||||
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 4, 253)
|
||||
max_attribute_size = min(bearer.att_mtu - 4, 253)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
# We need to truncate
|
||||
attribute_value = attribute_value[:max_attribute_size]
|
||||
@@ -756,11 +827,11 @@ class Server(utils.EventEmitter):
|
||||
else:
|
||||
logging.debug(f"not found {request}")
|
||||
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(
|
||||
self, connection: Connection, request: att.ATT_Read_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Read_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
@@ -769,7 +840,7 @@ class Server(utils.EventEmitter):
|
||||
response: att.ATT_PDU
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = await attribute.read_value(connection)
|
||||
value = await attribute.read_value(bearer)
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -777,7 +848,7 @@ class Server(utils.EventEmitter):
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
value_size = min(connection.att_mtu - 1, len(value))
|
||||
value_size = min(bearer.att_mtu - 1, len(value))
|
||||
response = att.ATT_Read_Response(attribute_value=value[:value_size])
|
||||
else:
|
||||
response = att.ATT_Error_Response(
|
||||
@@ -785,11 +856,11 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(
|
||||
self, connection: Connection, request: att.ATT_Read_Blob_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
@@ -798,7 +869,7 @@ class Server(utils.EventEmitter):
|
||||
response: att.ATT_PDU
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = await attribute.read_value(connection)
|
||||
value = await attribute.read_value(bearer)
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -812,7 +883,7 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=att.ATT_INVALID_OFFSET_ERROR,
|
||||
)
|
||||
elif len(value) <= connection.att_mtu - 1:
|
||||
elif len(value) <= bearer.att_mtu - 1:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
@@ -820,7 +891,7 @@ class Server(utils.EventEmitter):
|
||||
)
|
||||
else:
|
||||
part_size = min(
|
||||
connection.att_mtu - 1, len(value) - request.value_offset
|
||||
bearer.att_mtu - 1, len(value) - request.value_offset
|
||||
)
|
||||
response = att.ATT_Read_Blob_Response(
|
||||
part_attribute_value=value[
|
||||
@@ -833,11 +904,11 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(
|
||||
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
@@ -852,10 +923,10 @@ class Server(utils.EventEmitter):
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
return
|
||||
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
pdu_space_available = bearer.att_mtu - 2
|
||||
attributes: list[tuple[int, int, bytes]] = []
|
||||
for attribute in (
|
||||
attribute
|
||||
@@ -867,9 +938,9 @@ class Server(utils.EventEmitter):
|
||||
):
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(bearer)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
max_attribute_size = min(bearer.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
# We need to truncate
|
||||
attribute_value = attribute_value[:max_attribute_size]
|
||||
@@ -904,11 +975,11 @@ class Server(utils.EventEmitter):
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(
|
||||
self, connection: Connection, request: att.ATT_Write_Request
|
||||
self, bearer: att.Bearer, request: att.ATT_Write_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
@@ -918,7 +989,7 @@ class Server(utils.EventEmitter):
|
||||
attribute = self.get_attribute(request.attribute_handle)
|
||||
if attribute is None:
|
||||
self.send_response(
|
||||
connection,
|
||||
bearer,
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
@@ -932,7 +1003,7 @@ class Server(utils.EventEmitter):
|
||||
# Check the request parameters
|
||||
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
|
||||
self.send_response(
|
||||
connection,
|
||||
bearer,
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
@@ -944,7 +1015,7 @@ class Server(utils.EventEmitter):
|
||||
response: att.ATT_PDU
|
||||
try:
|
||||
# Accept the value
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(bearer, request.attribute_value)
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -954,11 +1025,11 @@ class Server(utils.EventEmitter):
|
||||
else:
|
||||
# Done
|
||||
response = att.ATT_Write_Response()
|
||||
self.send_response(connection, response)
|
||||
self.send_response(bearer, response)
|
||||
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(
|
||||
self, connection: Connection, request: att.ATT_Write_Command
|
||||
self, bearer: att.Bearer, request: att.ATT_Write_Command
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
@@ -977,22 +1048,20 @@ class Server(utils.EventEmitter):
|
||||
|
||||
# Accept the value
|
||||
try:
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(bearer, request.attribute_value)
|
||||
except Exception:
|
||||
logger.exception('!!! ignoring exception')
|
||||
|
||||
def on_att_handle_value_confirmation(
|
||||
self,
|
||||
connection: Connection,
|
||||
bearer: att.Bearer,
|
||||
confirmation: att.ATT_Handle_Value_Confirmation,
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
|
||||
'''
|
||||
del confirmation # Unused.
|
||||
if (
|
||||
pending_confirmation := self.pending_confirmations[connection.handle]
|
||||
) is None:
|
||||
if (pending_confirmation := self.pending_confirmations[bearer]) is None:
|
||||
# Not expected!
|
||||
logger.warning(
|
||||
'!!! unexpected confirmation, there is no pending indication'
|
||||
|
||||
@@ -1552,6 +1552,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
|
||||
EVENT_OPEN = "open"
|
||||
EVENT_CLOSE = "close"
|
||||
EVENT_ATT_MTU_UPDATE = "att_mtu_update"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1591,6 +1592,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
self.connection_result = None
|
||||
self.disconnection_result = None
|
||||
self.drained = asyncio.Event()
|
||||
self.att_mtu = 0 # Filled by GATT client or server later.
|
||||
|
||||
self.drained.set()
|
||||
|
||||
@@ -1821,6 +1823,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
|
||||
def on_att_mtu_update(self, mtu: int) -> None:
|
||||
self.att_mtu = mtu
|
||||
self.emit(self.EVENT_ATT_MTU_UPDATE, mtu)
|
||||
|
||||
def flush_output(self) -> None:
|
||||
self.out_queue.clear()
|
||||
self.out_sdu = None
|
||||
|
||||
@@ -19,10 +19,10 @@ import asyncio
|
||||
import sys
|
||||
|
||||
import bumble.logging
|
||||
from bumble import gatt_client
|
||||
from bumble.colors import color
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import show_services
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
@@ -34,24 +34,27 @@ class Listener(Device.Listener):
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
# pylint: disable=invalid-overridden-method
|
||||
async def on_connection(self, connection):
|
||||
async def on_connection(self, connection: Connection):
|
||||
print(f'=== Connected to {connection}')
|
||||
|
||||
# Discover all services
|
||||
print('=== Discovering services')
|
||||
peer = Peer(connection)
|
||||
await peer.discover_services()
|
||||
for service in peer.services:
|
||||
if connection.device.config.eatt_enabled:
|
||||
client = await gatt_client.Client.connect_eatt(connection)
|
||||
else:
|
||||
client = connection.gatt_client
|
||||
await client.discover_services()
|
||||
for service in client.services:
|
||||
await service.discover_characteristics()
|
||||
for characteristic in service.characteristics:
|
||||
await characteristic.discover_descriptors()
|
||||
|
||||
print('=== Services discovered')
|
||||
show_services(peer.services)
|
||||
gatt_client.show_services(client.services)
|
||||
|
||||
# Discover all attributes
|
||||
print('=== Discovering attributes')
|
||||
attributes = await peer.discover_attributes()
|
||||
attributes = await client.discover_attributes()
|
||||
for attribute in attributes:
|
||||
print(attribute)
|
||||
print('=== Attributes discovered')
|
||||
@@ -59,7 +62,7 @@ class Listener(Device.Listener):
|
||||
# Read all attributes
|
||||
for attribute in attributes:
|
||||
try:
|
||||
value = await peer.read_value(attribute)
|
||||
value = await client.read_value(attribute)
|
||||
print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green'))
|
||||
except ProtocolError as error:
|
||||
print(color(f'cannot read {attribute.handle:04X}:', 'red'), error)
|
||||
|
||||
@@ -28,6 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
|
||||
import pytest
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import gatt_client, l2cap
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_PDU,
|
||||
@@ -63,7 +64,6 @@ from bumble.gatt_adapters import (
|
||||
UTF8CharacteristicAdapter,
|
||||
UTF8CharacteristicProxyAdapter,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
|
||||
from .test_utils import Devices, TwoDevices, async_barrier
|
||||
|
||||
@@ -140,7 +140,7 @@ async def test_characteristic_encoding():
|
||||
await c.write_value(Mock(), bytes([122]))
|
||||
assert c.value == 122
|
||||
|
||||
class FooProxy(CharacteristicProxy):
|
||||
class FooProxy(gatt_client.CharacteristicProxy):
|
||||
def __init__(self, characteristic):
|
||||
super().__init__(
|
||||
characteristic.client,
|
||||
@@ -456,7 +456,7 @@ async def test_CharacteristicProxyAdapter() -> None:
|
||||
async def write_value(self, handle, value, with_response=False):
|
||||
self.value = value
|
||||
|
||||
class TestAttributeProxy(CharacteristicProxy):
|
||||
class TestAttributeProxy(gatt_client.CharacteristicProxy):
|
||||
def __init__(self, value) -> None:
|
||||
super().__init__(Client(value), 0, 0, None, 0) # type: ignore
|
||||
|
||||
@@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid():
|
||||
await peer.discover_characteristics()
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
|
||||
assert len(c) == 2
|
||||
assert isinstance(c[0], CharacteristicProxy)
|
||||
assert isinstance(c[0], gatt_client.CharacteristicProxy)
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
|
||||
assert len(c) == 1
|
||||
assert isinstance(c[0], CharacteristicProxy)
|
||||
assert isinstance(c[0], gatt_client.CharacteristicProxy)
|
||||
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
|
||||
assert len(c) == 0
|
||||
|
||||
@@ -1463,6 +1463,181 @@ async def test_write_return_error():
|
||||
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_eatt_read():
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
devices[1].gatt_server.register_eatt()
|
||||
|
||||
characteristic = Characteristic(
|
||||
'1234',
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.Permissions.READABLE,
|
||||
b'9999',
|
||||
)
|
||||
service = Service('ABCD', [characteristic])
|
||||
devices[1].add_service(service)
|
||||
|
||||
client = await gatt_client.Client.connect_eatt(devices.connections[0])
|
||||
await client.discover_services()
|
||||
service_proxy = client.get_services_by_uuid(service.uuid)[0]
|
||||
await service_proxy.discover_characteristics()
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
assert await characteristic_proxy.read_value() == b'9999'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_eatt_write():
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
devices[1].gatt_server.register_eatt()
|
||||
|
||||
write_queue = asyncio.Queue()
|
||||
characteristic = Characteristic(
|
||||
'1234',
|
||||
Characteristic.Properties.WRITE,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)),
|
||||
)
|
||||
service = Service('ABCD', [characteristic])
|
||||
devices[1].add_service(service)
|
||||
|
||||
client = await gatt_client.Client.connect_eatt(devices.connections[0])
|
||||
await client.discover_services()
|
||||
service_proxy = client.get_services_by_uuid(service.uuid)[0]
|
||||
await service_proxy.discover_characteristics()
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
await characteristic_proxy.write_value(b'9999')
|
||||
assert await write_queue.get() == (devices.connections[1], b'9999')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_eatt_notify():
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
devices[1].gatt_server.register_eatt()
|
||||
|
||||
characteristic = Characteristic(
|
||||
'1234',
|
||||
Characteristic.Properties.NOTIFY,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
)
|
||||
service = Service('ABCD', [characteristic])
|
||||
devices[1].add_service(service)
|
||||
|
||||
clients = [
|
||||
(
|
||||
devices.connections[0].gatt_client,
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
(
|
||||
await gatt_client.Client.connect_eatt(devices.connections[0]),
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
(
|
||||
await gatt_client.Client.connect_eatt(devices.connections[0]),
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
]
|
||||
for client, queue in clients:
|
||||
await client.discover_services()
|
||||
service_proxy = client.get_services_by_uuid(service.uuid)[0]
|
||||
await service_proxy.discover_characteristics()
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
|
||||
for client, queue in clients[:2]:
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True)
|
||||
|
||||
await devices[1].gatt_server.notify_subscribers(characteristic, b'1234')
|
||||
for _, queue in clients[:2]:
|
||||
assert await queue.get() == b'1234'
|
||||
assert queue.empty()
|
||||
assert clients[2][1].empty()
|
||||
|
||||
await devices[1].gatt_server.notify_subscriber(
|
||||
devices.connections[1], characteristic, b'5678'
|
||||
)
|
||||
for _, queue in clients[:2]:
|
||||
assert await queue.get() == b'5678'
|
||||
assert queue.empty()
|
||||
assert clients[2][1].empty()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_eatt_indicate():
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
devices[1].gatt_server.register_eatt()
|
||||
|
||||
characteristic = Characteristic(
|
||||
'1234',
|
||||
Characteristic.Properties.INDICATE,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
)
|
||||
service = Service('ABCD', [characteristic])
|
||||
devices[1].add_service(service)
|
||||
|
||||
clients = [
|
||||
(
|
||||
devices.connections[0].gatt_client,
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
(
|
||||
await gatt_client.Client.connect_eatt(devices.connections[0]),
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
(
|
||||
await gatt_client.Client.connect_eatt(devices.connections[0]),
|
||||
asyncio.Queue[bytes](),
|
||||
),
|
||||
]
|
||||
for client, queue in clients:
|
||||
await client.discover_services()
|
||||
service_proxy = client.get_services_by_uuid(service.uuid)[0]
|
||||
await service_proxy.discover_characteristics()
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
|
||||
for client, queue in clients[:2]:
|
||||
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
|
||||
characteristic.uuid
|
||||
)[0]
|
||||
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False)
|
||||
|
||||
await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234')
|
||||
for _, queue in clients[:2]:
|
||||
assert await queue.get() == b'1234'
|
||||
assert queue.empty()
|
||||
assert clients[2][1].empty()
|
||||
|
||||
await devices[1].gatt_server.indicate_subscriber(
|
||||
devices.connections[1], characteristic, b'5678'
|
||||
)
|
||||
for _, queue in clients[:2]:
|
||||
assert await queue.get() == b'5678'
|
||||
assert queue.empty()
|
||||
assert clients[2][1].empty()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_eatt_connection_failure():
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
|
||||
with pytest.raises(l2cap.L2capError):
|
||||
await gatt_client.Client.connect_eatt(devices.connections[0])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
|
||||
Reference in New Issue
Block a user