Merge pull request #836 from zxzxwu/eatt

Add EATT Support
This commit is contained in:
zxzxwu
2026-01-05 22:26:17 +08:00
committed by GitHub
9 changed files with 578 additions and 194 deletions

View File

@@ -298,6 +298,7 @@ class Speaker:
advertising_interval_max=25, advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'), address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS, identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
) )
device_config.le_enabled = True device_config.le_enabled = True

View File

@@ -34,10 +34,13 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
ClassVar, ClassVar,
Generic, Generic,
TypeAlias,
TypeVar, TypeVar,
) )
from bumble import hci, utils from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object from bumble.hci import HCI_Object
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
_T = TypeVar('_T') _T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -58,6 +69,7 @@ _T = TypeVar('_T')
ATT_CID = 0x04 ATT_CID = 0x04
ATT_PSM = 0x001F ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum): class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
@@ -780,6 +792,43 @@ class AttributeValue(Generic[_T]):
return self._write(connection, value) return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]): class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag): class Permissions(enum.IntFlag):
@@ -855,7 +904,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T: def decode_value(self, value: bytes) -> _T:
return value # type: ignore return value # type: ignore
async def read_value(self, connection: Connection) -> bytes: async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if ( if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION) (self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None and connection is not None
@@ -890,6 +940,17 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
) from error ) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else: else:
value = self.value value = self.value
@@ -897,7 +958,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value) return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Connection, value: bytes) -> None: async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if ( if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION) (self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None and connection is not None
@@ -931,6 +993,15 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
) from error ) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else: else:
self.value = decoded_value self.value = decoded_value

View File

@@ -41,6 +41,7 @@ from typing import (
from typing_extensions import Self from typing_extensions import Self
from bumble import ( from bumble import (
att,
core, core,
data_types, data_types,
gatt, gatt,
@@ -53,7 +54,6 @@ from bumble import (
smp, smp,
utils, utils,
) )
from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from bumble.colors import color from bumble.colors import color
from bumble.core import ( from bumble.core import (
AdvertisingData, AdvertisingData,
@@ -1743,7 +1743,6 @@ class Connection(utils.CompositeEventEmitter):
EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure" EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure"
EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update" EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update"
EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure" EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure"
EVENT_CONNECTION_ATT_MTU_UPDATE = "connection_att_mtu_update"
EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change" EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change"
EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = ( EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = (
"channel_sounding_capabilities_failure" "channel_sounding_capabilities_failure"
@@ -1846,7 +1845,7 @@ class Connection(utils.CompositeEventEmitter):
self.encryption_key_size = 0 self.encryption_key_size = 0
self.authenticated = False self.authenticated = False
self.sc = False self.sc = False
self.att_mtu = ATT_DEFAULT_MTU self.att_mtu = att.ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = gatt_client.Client(self) # Per-connection client self.gatt_client = gatt_client.Client(self) # Per-connection client
self.gatt_server = ( self.gatt_server = (
@@ -1996,6 +1995,15 @@ class Connection(utils.CompositeEventEmitter):
self.peer_le_features = await self.device.get_remote_le_features(self) self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features return self.peer_le_features
def on_att_mtu_update(self, mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
f'{self.peer_address} as {self.role_name}, '
f'{mtu}'
)
self.att_mtu = mtu
self.emit(self.EVENT_CONNECTION_ATT_MTU_UPDATE)
@property @property
def data_packet_queue(self) -> DataPacketQueue | None: def data_packet_queue(self) -> DataPacketQueue | None:
return self.device.host.get_data_packet_queue(self.handle) return self.device.host.get_data_packet_queue(self.handle)
@@ -2079,6 +2087,7 @@ class DeviceConfiguration:
l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION, l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION,
l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE, l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE,
) )
eatt_enabled: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.gatt_services: list[dict[str, Any]] = [] self.gatt_services: list[dict[str, Any]] = []
@@ -2497,7 +2506,10 @@ class Device(utils.CompositeEventEmitter):
add_gap_service=config.gap_service_enabled, add_gap_service=config.gap_service_enabled,
add_gatt_service=config.gatt_service_enabled, add_gatt_service=config.gatt_service_enabled,
) )
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu) self.l2cap_channel_manager.register_fixed_channel(att.ATT_CID, self.on_gatt_pdu)
if self.config.eatt_enabled:
self.gatt_server.register_eatt()
# Forward some events # Forward some events
utils.setup_event_forwarding( utils.setup_event_forwarding(
@@ -5140,7 +5152,11 @@ class Device(utils.CompositeEventEmitter):
if add_gap_service: if add_gap_service:
self.gatt_server.add_service(GenericAccessService(self.name)) self.gatt_server.add_service(GenericAccessService(self.name))
if add_gatt_service: if add_gatt_service:
self.gatt_service = gatt_service.GenericAttributeProfileService() self.gatt_service = gatt_service.GenericAttributeProfileService(
gatt.ServerSupportedFeatures.EATT_SUPPORTED
if self.config.eatt_enabled
else None
)
self.gatt_server.add_service(self.gatt_service) self.gatt_server.add_service(self.gatt_service)
async def notify_subscriber( async def notify_subscriber(
@@ -6240,17 +6256,6 @@ class Device(utils.CompositeEventEmitter):
) )
connection.emit(connection.EVENT_LE_SUBRATE_CHANGE) connection.emit(connection.EVENT_LE_SUBRATE_CHANGE)
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
f'{att_mtu}'
)
connection.att_mtu = att_mtu
connection.emit(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
@host_event_handler @host_event_handler
@with_connection_from_handle @with_connection_from_handle
def on_connection_data_length_change( def on_connection_data_length_change(
@@ -6437,7 +6442,7 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_handle @with_connection_from_handle
def on_gatt_pdu(self, connection: Connection, pdu: bytes): def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object # Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu) att_pdu = att.ATT_PDU.from_bytes(pdu)
# Conveniently, even-numbered op codes are client->server and # Conveniently, even-numbered op codes are client->server and
# odd-numbered ones are server->client # odd-numbered ones are server->client

View File

@@ -31,7 +31,7 @@ import struct
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import TypeVar from typing import TypeVar
from bumble.att import Attribute, AttributeValue from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID, BaseBumbleError from bumble.core import UUID, BaseBumbleError
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str: def __str__(self) -> str:
if isinstance(self.value, bytes): if isinstance(self.value, bytes):
value_str = self.value.hex() value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue): elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
value_str = '<dynamic>' value_str = '<dynamic>'
else: else:
value_str = '<...>' value_str = '<...>'

View File

@@ -26,6 +26,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import logging import logging
import struct import struct
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
@@ -35,9 +36,10 @@ from typing import (
Any, Any,
Generic, Generic,
TypeVar, TypeVar,
overload,
) )
from bumble import att, core, utils from bumble import att, core, l2cap, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID, InvalidStateError from bumble.core import UUID, InvalidStateError
from bumble.gatt import ( from bumble.gatt import (
@@ -54,12 +56,12 @@ from bumble.gatt import (
) )
from bumble.hci import HCI_Constant from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Typing # Typing
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T') _T = TypeVar('_T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -267,8 +269,8 @@ class Client:
pending_response: asyncio.futures.Future[att.ATT_PDU] | None pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None: def __init__(self, bearer: att.Bearer) -> None:
self.connection = connection self.bearer = bearer
self.mtu_exchange_done = False self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None self.pending_request = None
@@ -278,21 +280,78 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection) if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None: def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(att.ATT_CID, pdu) if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None: async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug( logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(bytes(command)) self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU): async def send_request(self, request: att.ATT_PDU):
logger.debug( logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection) # Wait until we can send (only one pending command at a time for the connection)
response = None response = None
@@ -321,10 +380,7 @@ class Client:
def send_confirmation( def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None: ) -> None:
logger.debug( logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(bytes(confirmation)) self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
@@ -336,7 +392,7 @@ class Client:
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:
return self.connection.att_mtu return self.mtu
# Send the request # Send the request
self.mtu_exchange_done = True self.mtu_exchange_done = True
@@ -347,9 +403,9 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response) raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU # Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu) self.mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu return self.mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]: def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
@@ -942,7 +998,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller # If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that # specifically asked not to do that
attribute_value = response.attribute_value attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1: if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value') logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value) offset = len(attribute_value)
while True: while True:
@@ -966,7 +1022,7 @@ class Client:
part = response.part_attribute_value part = response.part_attribute_value
attribute_value += part attribute_value += part
if len(part) < self.connection.att_mtu - 1: if len(part) < self.mtu - 1:
break break
offset += len(part) offset += len(part)
@@ -1062,14 +1118,13 @@ class Client:
) )
) )
def on_disconnection(self, _) -> None: def on_disconnection(self, *args) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done(): if self.pending_response and not self.pending_response.done():
self.pending_response.cancel() self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None: def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug( logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in att.ATT_RESPONSES: if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None: if self.pending_request is None:
# Not expected! # Not expected!
@@ -1099,8 +1154,7 @@ class Client:
else: else:
logger.warning( logger.warning(
color( color(
'--- Ignoring GATT Response from ' '--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
f'[0x{self.connection.handle:04X}]: ',
'red', 'red',
) )
+ str(att_pdu) + str(att_pdu)

View File

@@ -32,9 +32,8 @@ from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
from bumble import att, utils from bumble import att, core, l2cap, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID
from bumble.gatt import ( from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -44,14 +43,13 @@ from bumble.gatt import (
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
CharacteristicDeclaration, CharacteristicDeclaration,
CharacteristicValue,
Descriptor, Descriptor,
IncludedServiceDeclaration, IncludedServiceDeclaration,
Service, Service,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Connection, Device from bumble.device import Device
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -65,6 +63,18 @@ logger = logging.getLogger(__name__)
GATT_SERVER_DEFAULT_MAX_MTU = 517 GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
else:
return f'[0x{bearer.handle:04X}]'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GATT Server # GATT Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -72,9 +82,9 @@ class Server(utils.EventEmitter):
attributes: list[att.Attribute] attributes: list[att.Attribute]
services: list[Service] services: list[Service]
attributes_by_handle: dict[int, att.Attribute] attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[int, dict[int, bytes]] subscribers: dict[att.Bearer, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore] indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
pending_confirmations: defaultdict[int, asyncio.futures.Future | None] pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription" EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -96,8 +106,29 @@ class Server(utils.EventEmitter):
def __str__(self) -> str: def __str__(self) -> str:
return "\n".join(map(str, self.attributes)) return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None: def register_eatt(
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu) self, spec: l2cap.LeCreditBasedChannelSpec | None = None
) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug(
"New EATT Bearer Conenction=0x%04X CID=0x%04X",
channel.connection.handle,
channel.source_cid,
)
channel.att_mtu = att.ATT_DEFAULT_MTU
channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu)
)
return self.device.create_l2cap_server(
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
)
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
if att.is_enhanced_bearer(bearer):
bearer.write(pdu)
else:
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
def next_handle(self) -> int: def next_handle(self) -> int:
return 1 + len(self.attributes) return 1 + len(self.attributes)
@@ -138,7 +169,7 @@ class Server(utils.EventEmitter):
None, None,
) )
def get_service_attribute(self, service_uuid: UUID) -> Service | None: def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
return next( return next(
( (
attribute attribute
@@ -151,7 +182,7 @@ class Server(utils.EventEmitter):
) )
def get_characteristic_attributes( def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID self, service_uuid: core.UUID, characteristic_uuid: core.UUID
) -> tuple[CharacteristicDeclaration, Characteristic] | None: ) -> tuple[CharacteristicDeclaration, Characteristic] | None:
service_handle = self.get_service_attribute(service_uuid) service_handle = self.get_service_attribute(service_uuid)
if not service_handle: if not service_handle:
@@ -176,7 +207,10 @@ class Server(utils.EventEmitter):
) )
def get_descriptor_attribute( def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID self,
service_uuid: core.UUID,
characteristic_uuid: core.UUID,
descriptor_uuid: core.UUID,
) -> Descriptor | None: ) -> Descriptor | None:
characteristics = self.get_characteristic_attributes( characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid service_uuid, characteristic_uuid
@@ -257,14 +291,7 @@ class Server(utils.EventEmitter):
Descriptor( Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
att.Attribute.READABLE | att.Attribute.WRITEABLE, att.Attribute.READABLE | att.Attribute.WRITEABLE,
CharacteristicValue( self.make_descriptor_value(characteristic),
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
) )
) )
@@ -280,10 +307,21 @@ class Server(utils.EventEmitter):
for service in services: for service in services:
self.add_service(service) self.add_service(service)
def read_cccd( def make_descriptor_value(
self, connection: Connection, characteristic: Characteristic self, characteristic: Characteristic
) -> bytes: ) -> att.AttributeValueV2:
subscribers = self.subscribers.get(connection.handle) # It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
return att.AttributeValueV2(
lambda bearer, characteristic=characteristic: self.read_cccd(
bearer, characteristic
),
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
bearer, characteristic, value
),
)
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
subscribers = self.subscribers.get(bearer)
cccd = None cccd = None
if subscribers: if subscribers:
cccd = subscribers.get(characteristic.handle) cccd = subscribers.get(characteristic.handle)
@@ -292,12 +330,12 @@ class Server(utils.EventEmitter):
def write_cccd( def write_cccd(
self, self,
connection: Connection, bearer: att.Bearer,
characteristic: Characteristic, characteristic: Characteristic,
value: bytes, value: bytes,
) -> None: ) -> None:
logger.debug( logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, ' f'Subscription update for connection={_bearer_id(bearer)}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}' f'handle=0x{characteristic.handle:04X}: {value.hex()}'
) )
@@ -306,41 +344,60 @@ class Server(utils.EventEmitter):
logger.warning('CCCD value not 2 bytes long') logger.warning('CCCD value not 2 bytes long')
return return
cccds = self.subscribers.setdefault(connection.handle, {}) cccds = self.subscribers.setdefault(bearer, {})
cccds[characteristic.handle] = value cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}') logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0 notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0 indicate_enabled = value[0] & 0x02 != 0
characteristic.emit( characteristic.emit(
characteristic.EVENT_SUBSCRIPTION, characteristic.EVENT_SUBSCRIPTION,
connection, bearer,
notify_enabled, notify_enabled,
indicate_enabled, indicate_enabled,
) )
self.emit( self.emit(
self.EVENT_CHARACTERISTIC_SUBSCRIPTION, self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
connection, bearer,
characteristic, characteristic,
notify_enabled, notify_enabled,
indicate_enabled, indicate_enabled,
) )
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None: def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
logger.debug( logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
f'GATT Response from server: [0x{connection.handle:04X}] {response}' self.send_gatt_pdu(bearer, bytes(response))
)
self.send_gatt_pdu(connection.handle, bytes(response))
async def notify_subscriber( async def notify_subscriber(
self, self,
connection: Connection, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute,
value: bytes | None = None, value: bytes | None = None,
force: bool = False, force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._notify_single_subscriber(bearer, attribute, value, force)
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None: ) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(bearer)
if not subscribers: if not subscribers:
logger.debug('not notifying, no subscribers') logger.debug('not notifying, no subscribers')
return return
@@ -356,34 +413,53 @@ class Server(utils.EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
await attribute.read_value(connection) await attribute.read_value(bearer)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
# Truncate if needed # Truncate if needed
if len(value) > connection.att_mtu - 3: if len(value) > bearer.att_mtu - 3:
value = value[: connection.att_mtu - 3] value = value[: bearer.att_mtu - 3]
# Notify # Notify
notification = att.ATT_Handle_Value_Notification( notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value attribute_handle=attribute.handle, attribute_value=value
) )
logger.debug( logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}' self.send_gatt_pdu(bearer, bytes(notification))
)
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber( async def indicate_subscriber(
self, self,
connection: Connection, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute,
value: bytes | None = None, value: bytes | None = None,
force: bool = False, force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._indicate_single_bearer(bearer, attribute, value, force)
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None: ) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(bearer)
if not subscribers: if not subscribers:
logger.debug('not indicating, no subscribers') logger.debug('not indicating, no subscribers')
return return
@@ -399,40 +475,38 @@ class Server(utils.EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
await attribute.read_value(connection) await attribute.read_value(bearer)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
# Truncate if needed # Truncate if needed
if len(value) > connection.att_mtu - 3: if len(value) > bearer.att_mtu - 3:
value = value[: connection.att_mtu - 3] value = value[: bearer.att_mtu - 3]
# Indicate # Indicate
indication = att.ATT_Handle_Value_Indication( indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value attribute_handle=attribute.handle, attribute_value=value
) )
logger.debug( logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
# Wait until we can send (only one pending indication at a time per connection) # Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]: async with self.indication_semaphores[bearer]:
assert self.pending_confirmations[connection.handle] is None assert self.pending_confirmations[bearer] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[connection.handle] = ( pending_confirmation = self.pending_confirmations[bearer] = (
asyncio.get_running_loop().create_future() asyncio.get_running_loop().create_future()
) )
try: try:
self.send_gatt_pdu(connection.handle, bytes(indication)) self.send_gatt_pdu(bearer, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally: finally:
self.pending_confirmations[connection.handle] = None self.pending_confirmations[bearer] = None
async def _notify_or_indicate_subscribers( async def _notify_or_indicate_subscribers(
self, self,
@@ -441,24 +515,24 @@ class Server(utils.EventEmitter):
value: bytes | None = None, value: bytes | None = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
# Get all the connections for which there's at least one subscription # Get all the bearers for which there's at least one subscription
connections = [ bearers: list[att.Bearer] = [
connection bearer
for connection in [ for bearer, subscribers in self.subscribers.items()
self.device.lookup_connection(connection_handle) if force or subscribers.get(attribute.handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
] ]
# Indicate or notify for each connection # Indicate or notify for each connection
if connections: if bearers:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber coroutine = (
self._indicate_single_bearer
if indicate
else self._notify_single_subscriber
)
await asyncio.wait( await asyncio.wait(
[ [
asyncio.create_task(coroutine(connection, attribute, value, force)) asyncio.create_task(coroutine(bearer, attribute, value, force))
for connection in connections for bearer in bearers
] ]
) )
@@ -480,21 +554,18 @@ class Server(utils.EventEmitter):
): ):
return await self._notify_or_indicate_subscribers(True, attribute, value, force) return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection: Connection) -> None: def on_disconnection(self, bearer: att.Bearer) -> None:
if connection.handle in self.subscribers: self.subscribers.pop(bearer, None)
del self.subscribers[connection.handle] self.indication_semaphores.pop(bearer, None)
if connection.handle in self.indication_semaphores: self.pending_confirmations.pop(bearer, None)
del self.indication_semaphores[connection.handle]
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}' handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
if handler is not None: if handler is not None:
try: try:
handler(connection, att_pdu) handler(bearer, att_pdu)
except att.ATT_Error as error: except att.ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}') logger.debug(f'normal exception returned by handler: {error}')
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
@@ -502,7 +573,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=error.att_handle, attribute_handle_in_error=error.att_handle,
error_code=error.error_code, error_code=error.error_code,
) )
self.send_response(connection, response) self.send_response(bearer, response)
except Exception: except Exception:
logger.exception(color("!!! Exception in handler:", "red")) logger.exception(color("!!! Exception in handler:", "red"))
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
@@ -510,18 +581,18 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000, attribute_handle_in_error=0x0000,
error_code=att.ATT_UNLIKELY_ERROR_ERROR, error_code=att.ATT_UNLIKELY_ERROR_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
raise raise
else: else:
# No specific handler registered # No specific handler registered
if att_pdu.op_code in att.ATT_REQUESTS: if att_pdu.op_code in att.ATT_REQUESTS:
# Invoke the generic handler # Invoke the generic handler
self.on_att_request(connection, att_pdu) self.on_att_request(bearer, att_pdu)
else: else:
# Just ignore # Just ignore
logger.warning( logger.warning(
color( color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ', f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
'red', 'red',
) )
+ str(att_pdu) + str(att_pdu)
@@ -530,13 +601,14 @@ class Server(utils.EventEmitter):
####################################################### #######################################################
# ATT handlers # ATT handlers
####################################################### #######################################################
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None: def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
''' '''
Handler for requests without a more specific handler Handler for requests without a more specific handler
''' '''
logger.warning( logger.warning(
color( color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red' f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
'red',
) )
+ str(pdu) + str(pdu)
) )
@@ -545,29 +617,28 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000, attribute_handle_in_error=0x0000,
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR, error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
def on_att_exchange_mtu_request( def on_att_exchange_mtu_request(
self, connection: Connection, request: att.ATT_Exchange_MTU_Request self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
''' '''
self.send_response( self.send_response(
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
) )
# Compute the final MTU # Compute the final MTU
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU: if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu) mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device bearer.on_att_mtu_update(mtu)
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else: else:
logger.warning('invalid client_rx_mtu received, MTU not changed') logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request( def on_att_find_information_request(
self, connection: Connection, request: att.ATT_Find_Information_Request self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -580,7 +651,7 @@ class Server(utils.EventEmitter):
or request.starting_handle > request.ending_handle or request.starting_handle > request.ending_handle
): ):
self.send_response( self.send_response(
connection, bearer,
att.ATT_Error_Response( att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
@@ -590,7 +661,7 @@ class Server(utils.EventEmitter):
return return
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = connection.att_mtu - 2 pdu_space_available = bearer.att_mtu - 2
attributes: list[att.Attribute] = [] attributes: list[att.Attribute] = []
uuid_size = 0 uuid_size = 0
for attribute in ( for attribute in (
@@ -632,18 +703,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request( async def on_att_find_by_type_value_request(
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
''' '''
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = connection.att_mtu - 2 pdu_space_available = bearer.att_mtu - 2
attributes = [] attributes = []
response: att.ATT_PDU response: att.ATT_PDU
async for attribute in ( async for attribute in (
@@ -652,7 +723,7 @@ class Server(utils.EventEmitter):
if attribute.handle >= request.starting_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value and (await attribute.read_value(bearer)) == request.attribute_value
and pdu_space_available >= 4 and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -688,17 +759,17 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request( async def on_att_read_by_type_request(
self, connection: Connection, request: att.ATT_Read_By_Type_Request self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
pdu_space_available = connection.att_mtu - 2 pdu_space_available = bearer.att_mtu - 2
response: att.ATT_PDU = att.ATT_Error_Response( response: att.ATT_PDU = att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -716,7 +787,7 @@ class Server(utils.EventEmitter):
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = await attribute.read_value(connection) attribute_value = await attribute.read_value(bearer)
except att.ATT_Error as error: except att.ATT_Error as error:
# If the first attribute is unreadable, return an error # If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point # Otherwise return attributes up to this point
@@ -729,7 +800,7 @@ class Server(utils.EventEmitter):
break break
# Check the attribute value size # Check the attribute value size
max_attribute_size = min(connection.att_mtu - 4, 253) max_attribute_size = min(bearer.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
@@ -756,11 +827,11 @@ class Server(utils.EventEmitter):
else: else:
logging.debug(f"not found {request}") logging.debug(f"not found {request}")
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_request( async def on_att_read_request(
self, connection: Connection, request: att.ATT_Read_Request self, bearer: att.Bearer, request: att.ATT_Read_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -769,7 +840,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = await attribute.read_value(connection) value = await attribute.read_value(bearer)
except att.ATT_Error as error: except att.ATT_Error as error:
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -777,7 +848,7 @@ class Server(utils.EventEmitter):
error_code=error.error_code, error_code=error.error_code,
) )
else: else:
value_size = min(connection.att_mtu - 1, len(value)) value_size = min(bearer.att_mtu - 1, len(value))
response = att.ATT_Read_Response(attribute_value=value[:value_size]) response = att.ATT_Read_Response(attribute_value=value[:value_size])
else: else:
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
@@ -785,11 +856,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR, error_code=att.ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request( async def on_att_read_blob_request(
self, connection: Connection, request: att.ATT_Read_Blob_Request self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -798,7 +869,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = await attribute.read_value(connection) value = await attribute.read_value(bearer)
except att.ATT_Error as error: except att.ATT_Error as error:
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -812,7 +883,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_OFFSET_ERROR, error_code=att.ATT_INVALID_OFFSET_ERROR,
) )
elif len(value) <= connection.att_mtu - 1: elif len(value) <= bearer.att_mtu - 1:
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
@@ -820,7 +891,7 @@ class Server(utils.EventEmitter):
) )
else: else:
part_size = min( part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset bearer.att_mtu - 1, len(value) - request.value_offset
) )
response = att.ATT_Read_Blob_Response( response = att.ATT_Read_Blob_Response(
part_attribute_value=value[ part_attribute_value=value[
@@ -833,11 +904,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR, error_code=att.ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request( async def on_att_read_by_group_type_request(
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -852,10 +923,10 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR, error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
return return
pdu_space_available = connection.att_mtu - 2 pdu_space_available = bearer.att_mtu - 2
attributes: list[tuple[int, int, bytes]] = [] attributes: list[tuple[int, int, bytes]] = []
for attribute in ( for attribute in (
attribute attribute
@@ -867,9 +938,9 @@ class Server(utils.EventEmitter):
): ):
# No need to catch permission errors here, since these attributes # No need to catch permission errors here, since these attributes
# must all be world-readable # must all be world-readable
attribute_value = await attribute.read_value(connection) attribute_value = await attribute.read_value(bearer)
# Check the attribute value size # Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251) max_attribute_size = min(bearer.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
@@ -904,11 +975,11 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_write_request( async def on_att_write_request(
self, connection: Connection, request: att.ATT_Write_Request self, bearer: att.Bearer, request: att.ATT_Write_Request
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -918,7 +989,7 @@ class Server(utils.EventEmitter):
attribute = self.get_attribute(request.attribute_handle) attribute = self.get_attribute(request.attribute_handle)
if attribute is None: if attribute is None:
self.send_response( self.send_response(
connection, bearer,
att.ATT_Error_Response( att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
@@ -932,7 +1003,7 @@ class Server(utils.EventEmitter):
# Check the request parameters # Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE: if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response( self.send_response(
connection, bearer,
att.ATT_Error_Response( att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
@@ -944,7 +1015,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU response: att.ATT_PDU
try: try:
# Accept the value # Accept the value
await attribute.write_value(connection, request.attribute_value) await attribute.write_value(bearer, request.attribute_value)
except att.ATT_Error as error: except att.ATT_Error as error:
response = att.ATT_Error_Response( response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -954,11 +1025,11 @@ class Server(utils.EventEmitter):
else: else:
# Done # Done
response = att.ATT_Write_Response() response = att.ATT_Write_Response()
self.send_response(connection, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_write_command( async def on_att_write_command(
self, connection: Connection, request: att.ATT_Write_Command self, bearer: att.Bearer, request: att.ATT_Write_Command
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
@@ -977,22 +1048,20 @@ class Server(utils.EventEmitter):
# Accept the value # Accept the value
try: try:
await attribute.write_value(connection, request.attribute_value) await attribute.write_value(bearer, request.attribute_value)
except Exception: except Exception:
logger.exception('!!! ignoring exception') logger.exception('!!! ignoring exception')
def on_att_handle_value_confirmation( def on_att_handle_value_confirmation(
self, self,
connection: Connection, bearer: att.Bearer,
confirmation: att.ATT_Handle_Value_Confirmation, confirmation: att.ATT_Handle_Value_Confirmation,
): ):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
''' '''
del confirmation # Unused. del confirmation # Unused.
if ( if (pending_confirmation := self.pending_confirmations[bearer]) is None:
pending_confirmation := self.pending_confirmations[connection.handle]
) is None:
# Not expected! # Not expected!
logger.warning( logger.warning(
'!!! unexpected confirmation, there is no pending indication' '!!! unexpected confirmation, there is no pending indication'

View File

@@ -1552,6 +1552,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
EVENT_OPEN = "open" EVENT_OPEN = "open"
EVENT_CLOSE = "close" EVENT_CLOSE = "close"
EVENT_ATT_MTU_UPDATE = "att_mtu_update"
def __init__( def __init__(
self, self,
@@ -1591,6 +1592,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.drained = asyncio.Event() self.drained = asyncio.Event()
self.att_mtu = 0 # Filled by GATT client or server later.
self.drained.set() self.drained.set()
@@ -1821,6 +1823,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
def on_att_mtu_update(self, mtu: int) -> None:
self.att_mtu = mtu
self.emit(self.EVENT_ATT_MTU_UPDATE, mtu)
def flush_output(self) -> None: def flush_output(self) -> None:
self.out_queue.clear() self.out_queue.clear()
self.out_sdu = None self.out_sdu = None

View File

@@ -19,10 +19,10 @@ import asyncio
import sys import sys
import bumble.logging import bumble.logging
from bumble import gatt_client
from bumble.colors import color from bumble.colors import color
from bumble.core import ProtocolError from bumble.core import ProtocolError
from bumble.device import Device, Peer from bumble.device import Connection, Device
from bumble.gatt import show_services
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
@@ -34,24 +34,27 @@ class Listener(Device.Listener):
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method # pylint: disable=invalid-overridden-method
async def on_connection(self, connection): async def on_connection(self, connection: Connection):
print(f'=== Connected to {connection}') print(f'=== Connected to {connection}')
# Discover all services # Discover all services
print('=== Discovering services') print('=== Discovering services')
peer = Peer(connection) if connection.device.config.eatt_enabled:
await peer.discover_services() client = await gatt_client.Client.connect_eatt(connection)
for service in peer.services: else:
client = connection.gatt_client
await client.discover_services()
for service in client.services:
await service.discover_characteristics() await service.discover_characteristics()
for characteristic in service.characteristics: for characteristic in service.characteristics:
await characteristic.discover_descriptors() await characteristic.discover_descriptors()
print('=== Services discovered') print('=== Services discovered')
show_services(peer.services) gatt_client.show_services(client.services)
# Discover all attributes # Discover all attributes
print('=== Discovering attributes') print('=== Discovering attributes')
attributes = await peer.discover_attributes() attributes = await client.discover_attributes()
for attribute in attributes: for attribute in attributes:
print(attribute) print(attribute)
print('=== Attributes discovered') print('=== Attributes discovered')
@@ -59,7 +62,7 @@ class Listener(Device.Listener):
# Read all attributes # Read all attributes
for attribute in attributes: for attribute in attributes:
try: try:
value = await peer.read_value(attribute) value = await client.read_value(attribute)
print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green')) print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green'))
except ProtocolError as error: except ProtocolError as error:
print(color(f'cannot read {attribute.handle:04X}:', 'red'), error) print(color(f'cannot read {attribute.handle:04X}:', 'red'), error)

View File

@@ -28,6 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
import pytest import pytest
from typing_extensions import Self from typing_extensions import Self
from bumble import gatt_client, l2cap
from bumble.att import ( from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU, ATT_PDU,
@@ -63,7 +64,6 @@ from bumble.gatt_adapters import (
UTF8CharacteristicAdapter, UTF8CharacteristicAdapter,
UTF8CharacteristicProxyAdapter, UTF8CharacteristicProxyAdapter,
) )
from bumble.gatt_client import CharacteristicProxy
from .test_utils import Devices, TwoDevices, async_barrier from .test_utils import Devices, TwoDevices, async_barrier
@@ -140,7 +140,7 @@ async def test_characteristic_encoding():
await c.write_value(Mock(), bytes([122])) await c.write_value(Mock(), bytes([122]))
assert c.value == 122 assert c.value == 122
class FooProxy(CharacteristicProxy): class FooProxy(gatt_client.CharacteristicProxy):
def __init__(self, characteristic): def __init__(self, characteristic):
super().__init__( super().__init__(
characteristic.client, characteristic.client,
@@ -456,7 +456,7 @@ async def test_CharacteristicProxyAdapter() -> None:
async def write_value(self, handle, value, with_response=False): async def write_value(self, handle, value, with_response=False):
self.value = value self.value = value
class TestAttributeProxy(CharacteristicProxy): class TestAttributeProxy(gatt_client.CharacteristicProxy):
def __init__(self, value) -> None: def __init__(self, value) -> None:
super().__init__(Client(value), 0, 0, None, 0) # type: ignore super().__init__(Client(value), 0, 0, None, 0) # type: ignore
@@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid():
await peer.discover_characteristics() await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234')) c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2 assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy) assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD')) c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1 assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy) assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA')) c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0 assert len(c) == 0
@@ -1463,6 +1463,181 @@ async def test_write_return_error():
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_read():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.READ,
Characteristic.Permissions.READABLE,
b'9999',
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
assert await characteristic_proxy.read_value() == b'9999'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_write():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
write_queue = asyncio.Queue()
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)),
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.write_value(b'9999')
assert await write_queue.get() == (devices.connections[1], b'9999')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_notify():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True)
await devices[1].gatt_server.notify_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.notify_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_indicate():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.INDICATE,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False)
await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.indicate_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_connection_failure():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await gatt_client.Client.connect_eatt(devices.connections[0])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())