Add EATT Support

This commit is contained in:
Josh Wu
2025-12-18 15:23:30 +08:00
parent b4261548e8
commit df697c6513
9 changed files with 578 additions and 194 deletions

View File

@@ -26,6 +26,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import struct
from collections.abc import Callable, Iterable
@@ -35,9 +36,10 @@ from typing import (
Any,
Generic,
TypeVar,
overload,
)
from bumble import att, core, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
@@ -54,12 +56,12 @@ from bumble.gatt import (
)
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T')
# -----------------------------------------------------------------------------
@@ -267,8 +269,8 @@ class Client:
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None:
self.connection = connection
def __init__(self, bearer: att.Bearer) -> None:
self.bearer = bearer
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -278,21 +280,78 @@ class Client:
self.services = []
self.cached_values = {}
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -321,10 +380,7 @@ class Client:
def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
@@ -336,7 +392,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return self.mtu
# Send the request
self.mtu_exchange_done = True
@@ -347,9 +403,9 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
self.mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
return self.mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
@@ -942,7 +998,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -966,7 +1022,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
@@ -1062,14 +1118,13 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
def on_disconnection(self, *args) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -1099,8 +1154,7 @@ class Client:
else:
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
'red',
)
+ str(att_pdu)