Add EATT Support

This commit is contained in:
Josh Wu
2025-12-18 15:23:30 +08:00
parent b4261548e8
commit df697c6513
9 changed files with 578 additions and 194 deletions

View File

@@ -298,6 +298,7 @@ class Speaker:
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
)
device_config.le_enabled = True

View File

@@ -34,10 +34,13 @@ from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
TypeAlias,
TypeVar,
)
from bumble import hci, utils
from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
_T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -58,6 +69,7 @@ _T = TypeVar('_T')
ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
@@ -780,6 +792,43 @@ class AttributeValue(Generic[_T]):
return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
@@ -855,7 +904,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, connection: Connection) -> bytes:
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -890,6 +940,17 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
@@ -897,7 +958,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Connection, value: bytes) -> None:
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
@@ -931,6 +993,15 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value

View File

@@ -41,6 +41,7 @@ from typing import (
from typing_extensions import Self
from bumble import (
att,
core,
data_types,
gatt,
@@ -53,7 +54,6 @@ from bumble import (
smp,
utils,
)
from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from bumble.colors import color
from bumble.core import (
AdvertisingData,
@@ -1743,7 +1743,6 @@ class Connection(utils.CompositeEventEmitter):
EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure"
EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update"
EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure"
EVENT_CONNECTION_ATT_MTU_UPDATE = "connection_att_mtu_update"
EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change"
EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = (
"channel_sounding_capabilities_failure"
@@ -1846,7 +1845,7 @@ class Connection(utils.CompositeEventEmitter):
self.encryption_key_size = 0
self.authenticated = False
self.sc = False
self.att_mtu = ATT_DEFAULT_MTU
self.att_mtu = att.ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = gatt_client.Client(self) # Per-connection client
self.gatt_server = (
@@ -1996,6 +1995,15 @@ class Connection(utils.CompositeEventEmitter):
self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features
def on_att_mtu_update(self, mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
f'{self.peer_address} as {self.role_name}, '
f'{mtu}'
)
self.att_mtu = mtu
self.emit(self.EVENT_CONNECTION_ATT_MTU_UPDATE)
@property
def data_packet_queue(self) -> DataPacketQueue | None:
return self.device.host.get_data_packet_queue(self.handle)
@@ -2079,6 +2087,7 @@ class DeviceConfiguration:
l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION,
l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE,
)
eatt_enabled: bool = False
def __post_init__(self) -> None:
self.gatt_services: list[dict[str, Any]] = []
@@ -2497,7 +2506,10 @@ class Device(utils.CompositeEventEmitter):
add_gap_service=config.gap_service_enabled,
add_gatt_service=config.gatt_service_enabled,
)
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu)
self.l2cap_channel_manager.register_fixed_channel(att.ATT_CID, self.on_gatt_pdu)
if self.config.eatt_enabled:
self.gatt_server.register_eatt()
# Forward some events
utils.setup_event_forwarding(
@@ -5140,7 +5152,11 @@ class Device(utils.CompositeEventEmitter):
if add_gap_service:
self.gatt_server.add_service(GenericAccessService(self.name))
if add_gatt_service:
self.gatt_service = gatt_service.GenericAttributeProfileService()
self.gatt_service = gatt_service.GenericAttributeProfileService(
gatt.ServerSupportedFeatures.EATT_SUPPORTED
if self.config.eatt_enabled
else None
)
self.gatt_server.add_service(self.gatt_service)
async def notify_subscriber(
@@ -6240,17 +6256,6 @@ class Device(utils.CompositeEventEmitter):
)
connection.emit(connection.EVENT_LE_SUBRATE_CHANGE)
@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
f'{att_mtu}'
)
connection.att_mtu = att_mtu
connection.emit(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
@host_event_handler
@with_connection_from_handle
def on_connection_data_length_change(
@@ -6437,7 +6442,7 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_handle
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu)
att_pdu = att.ATT_PDU.from_bytes(pdu)
# Conveniently, even-numbered op codes are client->server and
# odd-numbered ones are server->client

View File

@@ -31,7 +31,7 @@ import struct
from collections.abc import Iterable, Sequence
from typing import TypeVar
from bumble.att import Attribute, AttributeValue
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color
from bumble.core import UUID, BaseBumbleError
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
value_str = '<dynamic>'
else:
value_str = '<...>'

View File

@@ -26,6 +26,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import struct
from collections.abc import Callable, Iterable
@@ -35,9 +36,10 @@ from typing import (
Any,
Generic,
TypeVar,
overload,
)
from bumble import att, core, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
@@ -54,12 +56,12 @@ from bumble.gatt import (
)
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T')
# -----------------------------------------------------------------------------
@@ -267,8 +269,8 @@ class Client:
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None:
self.connection = connection
def __init__(self, bearer: att.Bearer) -> None:
self.bearer = bearer
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -278,21 +280,78 @@ class Client:
self.services = []
self.cached_values = {}
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -321,10 +380,7 @@ class Client:
def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
@@ -336,7 +392,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return self.mtu
# Send the request
self.mtu_exchange_done = True
@@ -347,9 +403,9 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
self.mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
return self.mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
@@ -942,7 +998,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -966,7 +1022,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
@@ -1062,14 +1118,13 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
def on_disconnection(self, *args) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -1099,8 +1154,7 @@ class Client:
else:
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
'red',
)
+ str(att_pdu)

View File

@@ -32,9 +32,8 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar
from bumble import att, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -44,14 +43,13 @@ from bumble.gatt import (
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
IncludedServiceDeclaration,
Service,
)
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.device import Device
# -----------------------------------------------------------------------------
# Logging
@@ -65,6 +63,18 @@ logger = logging.getLogger(__name__)
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
else:
return f'[0x{bearer.handle:04X}]'
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -72,9 +82,9 @@ class Server(utils.EventEmitter):
attributes: list[att.Attribute]
services: list[Service]
attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[int, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, asyncio.futures.Future | None]
subscribers: dict[att.Bearer, dict[int, bytes]]
indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -96,8 +106,29 @@ class Server(utils.EventEmitter):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
def register_eatt(
self, spec: l2cap.LeCreditBasedChannelSpec | None = None
) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug(
"New EATT Bearer Conenction=0x%04X CID=0x%04X",
channel.connection.handle,
channel.source_cid,
)
channel.att_mtu = att.ATT_DEFAULT_MTU
channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu)
)
return self.device.create_l2cap_server(
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
)
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
if att.is_enhanced_bearer(bearer):
bearer.write(pdu)
else:
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
def next_handle(self) -> int:
return 1 + len(self.attributes)
@@ -138,7 +169,7 @@ class Server(utils.EventEmitter):
None,
)
def get_service_attribute(self, service_uuid: UUID) -> Service | None:
def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
return next(
(
attribute
@@ -151,7 +182,7 @@ class Server(utils.EventEmitter):
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
self, service_uuid: core.UUID, characteristic_uuid: core.UUID
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
@@ -176,7 +207,10 @@ class Server(utils.EventEmitter):
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
self,
service_uuid: core.UUID,
characteristic_uuid: core.UUID,
descriptor_uuid: core.UUID,
) -> Descriptor | None:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
@@ -257,14 +291,7 @@ class Server(utils.EventEmitter):
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
att.Attribute.READABLE | att.Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
self.make_descriptor_value(characteristic),
)
)
@@ -280,10 +307,21 @@ class Server(utils.EventEmitter):
for service in services:
self.add_service(service)
def read_cccd(
self, connection: Connection, characteristic: Characteristic
) -> bytes:
subscribers = self.subscribers.get(connection.handle)
def make_descriptor_value(
self, characteristic: Characteristic
) -> att.AttributeValueV2:
# It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
return att.AttributeValueV2(
lambda bearer, characteristic=characteristic: self.read_cccd(
bearer, characteristic
),
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
bearer, characteristic, value
),
)
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
subscribers = self.subscribers.get(bearer)
cccd = None
if subscribers:
cccd = subscribers.get(characteristic.handle)
@@ -292,12 +330,12 @@ class Server(utils.EventEmitter):
def write_cccd(
self,
connection: Connection,
bearer: att.Bearer,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'Subscription update for connection={_bearer_id(bearer)}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
@@ -306,41 +344,60 @@ class Server(utils.EventEmitter):
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
cccds = self.subscribers.setdefault(bearer, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
characteristic.EVENT_SUBSCRIPTION,
connection,
bearer,
notify_enabled,
indicate_enabled,
)
self.emit(
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
connection,
bearer,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, bytes(response))
def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
self.send_gatt_pdu(bearer, bytes(response))
async def notify_subscriber(
self,
connection: Connection,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._notify_single_subscriber(bearer, attribute, value, force)
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not notifying, no subscribers')
return
@@ -356,34 +413,53 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
async def indicate_subscriber(
self,
connection: Connection,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._indicate_single_bearer(bearer, attribute, value, force)
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not indicating, no subscribers')
return
@@ -399,40 +475,38 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert self.pending_confirmations[connection.handle] is None
async with self.indication_semaphores[bearer]:
assert self.pending_confirmations[bearer] is None
# Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[connection.handle] = (
pending_confirmation = self.pending_confirmations[bearer] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(connection.handle, bytes(indication))
self.send_gatt_pdu(bearer, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[connection.handle] = None
self.pending_confirmations[bearer] = None
async def _notify_or_indicate_subscribers(
self,
@@ -441,24 +515,24 @@ class Server(utils.EventEmitter):
value: bytes | None = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
# Get all the bearers for which there's at least one subscription
bearers: list[att.Bearer] = [
bearer
for bearer, subscribers in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
if bearers:
coroutine = (
self._indicate_single_bearer
if indicate
else self._notify_single_subscriber
)
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
asyncio.create_task(coroutine(bearer, attribute, value, force))
for bearer in bearers
]
)
@@ -480,21 +554,18 @@ class Server(utils.EventEmitter):
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
del self.indication_semaphores[connection.handle]
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_disconnection(self, bearer: att.Bearer) -> None:
self.subscribers.pop(bearer, None)
self.indication_semaphores.pop(bearer, None)
self.pending_confirmations.pop(bearer, None)
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(connection, att_pdu)
handler(bearer, att_pdu)
except att.ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = att.ATT_Error_Response(
@@ -502,7 +573,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
)
self.send_response(connection, response)
self.send_response(bearer, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = att.ATT_Error_Response(
@@ -510,18 +581,18 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
raise
else:
# No specific handler registered
if att_pdu.op_code in att.ATT_REQUESTS:
# Invoke the generic handler
self.on_att_request(connection, att_pdu)
self.on_att_request(bearer, att_pdu)
else:
# Just ignore
logger.warning(
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(att_pdu)
@@ -530,13 +601,14 @@ class Server(utils.EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(pdu)
)
@@ -545,29 +617,28 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
def on_att_exchange_mtu_request(
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
bearer.on_att_mtu_update(mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(
self, connection: Connection, request: att.ATT_Find_Information_Request
self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -580,7 +651,7 @@ class Server(utils.EventEmitter):
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
@@ -590,7 +661,7 @@ class Server(utils.EventEmitter):
return
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[att.Attribute] = []
uuid_size = 0
for attribute in (
@@ -632,18 +703,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes = []
response: att.ATT_PDU
async for attribute in (
@@ -652,7 +723,7 @@ class Server(utils.EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and (await attribute.read_value(bearer)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -688,17 +759,17 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request(
self, connection: Connection, request: att.ATT_Read_By_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
response: att.ATT_PDU = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -716,7 +787,7 @@ class Server(utils.EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
except att.ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -729,7 +800,7 @@ class Server(utils.EventEmitter):
break
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 4, 253)
max_attribute_size = min(bearer.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -756,11 +827,11 @@ class Server(utils.EventEmitter):
else:
logging.debug(f"not found {request}")
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_request(
self, connection: Connection, request: att.ATT_Read_Request
self, bearer: att.Bearer, request: att.ATT_Read_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -769,7 +840,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -777,7 +848,7 @@ class Server(utils.EventEmitter):
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
value_size = min(bearer.att_mtu - 1, len(value))
response = att.ATT_Read_Response(attribute_value=value[:value_size])
else:
response = att.ATT_Error_Response(
@@ -785,11 +856,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request(
self, connection: Connection, request: att.ATT_Read_Blob_Request
self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -798,7 +869,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -812,7 +883,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
elif len(value) <= bearer.att_mtu - 1:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -820,7 +891,7 @@ class Server(utils.EventEmitter):
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
bearer.att_mtu - 1, len(value) - request.value_offset
)
response = att.ATT_Read_Blob_Response(
part_attribute_value=value[
@@ -833,11 +904,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -852,10 +923,10 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
return
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[tuple[int, int, bytes]] = []
for attribute in (
attribute
@@ -867,9 +938,9 @@ class Server(utils.EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
max_attribute_size = min(bearer.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -904,11 +975,11 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_request(
self, connection: Connection, request: att.ATT_Write_Request
self, bearer: att.Bearer, request: att.ATT_Write_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -918,7 +989,7 @@ class Server(utils.EventEmitter):
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -932,7 +1003,7 @@ class Server(utils.EventEmitter):
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -944,7 +1015,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -954,11 +1025,11 @@ class Server(utils.EventEmitter):
else:
# Done
response = att.ATT_Write_Response()
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_command(
self, connection: Connection, request: att.ATT_Write_Command
self, bearer: att.Bearer, request: att.ATT_Write_Command
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
@@ -977,22 +1048,20 @@ class Server(utils.EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except Exception:
logger.exception('!!! ignoring exception')
def on_att_handle_value_confirmation(
self,
connection: Connection,
bearer: att.Bearer,
confirmation: att.ATT_Handle_Value_Confirmation,
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
del confirmation # Unused.
if (
pending_confirmation := self.pending_confirmations[connection.handle]
) is None:
if (pending_confirmation := self.pending_confirmations[bearer]) is None:
# Not expected!
logger.warning(
'!!! unexpected confirmation, there is no pending indication'

View File

@@ -1552,6 +1552,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
EVENT_ATT_MTU_UPDATE = "att_mtu_update"
def __init__(
self,
@@ -1591,6 +1592,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.drained = asyncio.Event()
self.att_mtu = 0 # Filled by GATT client or server later.
self.drained.set()
@@ -1821,6 +1823,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_att_mtu_update(self, mtu: int) -> None:
self.att_mtu = mtu
self.emit(self.EVENT_ATT_MTU_UPDATE, mtu)
def flush_output(self) -> None:
self.out_queue.clear()
self.out_sdu = None

View File

@@ -19,10 +19,10 @@ import asyncio
import sys
import bumble.logging
from bumble import gatt_client
from bumble.colors import color
from bumble.core import ProtocolError
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.device import Connection, Device
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -34,24 +34,27 @@ class Listener(Device.Listener):
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
async def on_connection(self, connection: Connection):
print(f'=== Connected to {connection}')
# Discover all services
print('=== Discovering services')
peer = Peer(connection)
await peer.discover_services()
for service in peer.services:
if connection.device.config.eatt_enabled:
client = await gatt_client.Client.connect_eatt(connection)
else:
client = connection.gatt_client
await client.discover_services()
for service in client.services:
await service.discover_characteristics()
for characteristic in service.characteristics:
await characteristic.discover_descriptors()
print('=== Services discovered')
show_services(peer.services)
gatt_client.show_services(client.services)
# Discover all attributes
print('=== Discovering attributes')
attributes = await peer.discover_attributes()
attributes = await client.discover_attributes()
for attribute in attributes:
print(attribute)
print('=== Attributes discovered')
@@ -59,7 +62,7 @@ class Listener(Device.Listener):
# Read all attributes
for attribute in attributes:
try:
value = await peer.read_value(attribute)
value = await client.read_value(attribute)
print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green'))
except ProtocolError as error:
print(color(f'cannot read {attribute.handle:04X}:', 'red'), error)

View File

@@ -28,6 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
import pytest
from typing_extensions import Self
from bumble import gatt_client, l2cap
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
@@ -63,7 +64,6 @@ from bumble.gatt_adapters import (
UTF8CharacteristicAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy
from .test_utils import Devices, TwoDevices, async_barrier
@@ -140,7 +140,7 @@ async def test_characteristic_encoding():
await c.write_value(Mock(), bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
class FooProxy(gatt_client.CharacteristicProxy):
def __init__(self, characteristic):
super().__init__(
characteristic.client,
@@ -456,7 +456,7 @@ async def test_CharacteristicProxyAdapter() -> None:
async def write_value(self, handle, value, with_response=False):
self.value = value
class TestAttributeProxy(CharacteristicProxy):
class TestAttributeProxy(gatt_client.CharacteristicProxy):
def __init__(self, value) -> None:
super().__init__(Client(value), 0, 0, None, 0) # type: ignore
@@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid():
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0
@@ -1463,6 +1463,181 @@ async def test_write_return_error():
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_read():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.READ,
Characteristic.Permissions.READABLE,
b'9999',
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
assert await characteristic_proxy.read_value() == b'9999'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_write():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
write_queue = asyncio.Queue()
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)),
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.write_value(b'9999')
assert await write_queue.get() == (devices.connections[1], b'9999')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_notify():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True)
await devices[1].gatt_server.notify_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.notify_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_indicate():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.INDICATE,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False)
await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.indicate_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_connection_failure():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await gatt_client.Client.connect_eatt(devices.connections[0])
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())