Add EATT Support

This commit is contained in:
Josh Wu
2025-12-18 15:23:30 +08:00
parent b4261548e8
commit df697c6513
9 changed files with 578 additions and 194 deletions

View File

@@ -34,10 +34,13 @@ from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
TypeAlias,
TypeVar,
)
from bumble import hci, utils
from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
_T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -58,6 +69,7 @@ _T = TypeVar('_T')
ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
@@ -780,6 +792,43 @@ class AttributeValue(Generic[_T]):
return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
@@ -855,7 +904,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, connection: Connection) -> bytes:
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -890,6 +940,17 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
@@ -897,7 +958,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Connection, value: bytes) -> None:
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
@@ -931,6 +993,15 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value