support async read/write for characteristic values

This commit is contained in:
Gilles Boccon-Gibod
2023-12-27 11:52:22 -08:00
parent 5d83deffa4
commit f2925ca647
10 changed files with 234 additions and 129 deletions

View File

@@ -777,7 +777,7 @@ class ConsoleApp:
if not service: if not service:
continue continue
values = [ values = [
attribute.read_value(connection) await attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic: if not characteristic:
continue continue
values = [ values = [
attribute.read_value(connection) await attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
values = [attribute.read_value(None)] values = [await attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string # TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers # send data to any subscribers
if isinstance(attribute, Characteristic): if isinstance(attribute, Characteristic):
attribute.write_value(None, value) await attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY): if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute) await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE): if attribute.has_properties(Characteristic.INDICATE):

View File

@@ -25,9 +25,21 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import functools import functools
import inspect
import struct import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value from bumble.hci import HCI_Object, key_with_value
@@ -722,12 +734,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConnectionValue(Protocol): class AttributeValue:
def read(self, connection) -> bytes: '''
... Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def write(self, connection, value: bytes) -> None: def __init__(
... self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -770,13 +808,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[str, bytes, ConnectionValue] value: Union[bytes, AttributeValue]
def __init__( def __init__(
self, self,
attribute_type: Union[str, bytes, UUID], attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'', value: Union[str, bytes, AttributeValue] = b'',
) -> None: ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
@@ -806,7 +844,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any: def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def read_value(self, connection: Optional[Connection]) -> bytes: async def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION) (self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None and connection is not None
@@ -832,6 +870,8 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'): if hasattr(self.value, 'read'):
try: try:
value = self.value.read(connection) value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -841,7 +881,7 @@ class Attribute(EventEmitter):
return self.encode_value(value) return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes: bytes) -> None: async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption: ) and not connection.encryption:
@@ -864,7 +904,9 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'): if hasattr(self.value, 'write'):
try: try:
self.value.write(connection, value) # pylint: disable=not-callable result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle

View File

@@ -23,16 +23,28 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import enum import enum
import functools import functools
import logging import logging
import struct import struct
from typing import Optional, Sequence, Iterable, List, Union from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from .colors import color from bumble.colors import color
from .core import UUID, get_dict_key_by_value from bumble.core import UUID
from .att import Attribute from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -522,56 +534,43 @@ class CharacteristicDeclaration(Attribute):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue: class CharacteristicValue(AttributeValue):
''' """Same as AttributeValue, for backward compatibility"""
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicAdapter: class CharacteristicAdapter:
''' '''
An adapter that can adapt any object with `read_value` and `write_value` An adapter that can adapt Characteristic and AttributeProxy objects
methods (like Characteristic and CharacteristicProxy objects) by wrapping by wrapping their `read_value()` and `write_value()` methods with ones that
those methods with ones that return/accept encoded/decoded values. return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to For proxies (i.e used by a GATT client), the adaptation is one where the return
`write_value` is encoded. Other objects are considered local characteristics value of `read_value()` is decoded and the value passed to `write_value()` is
so the adaptation is one where the return value of `read_value` is encoded encoded. The `subscribe()` method, is wrapped with one where the values are decoded
and the value passed to `write_value` is decoded. before being passed to the subscriber.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber. For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
''' '''
def __init__(self, characteristic): read_value: Callable
self.wrapped_characteristic = characteristic write_value: Callable
self.subscribers = {} # Map from subscriber to proxy subscriber
if asyncio.iscoroutinefunction( def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
characteristic.read_value self.wrapped_characteristic = characteristic
) and asyncio.iscoroutinefunction(characteristic.write_value): self.subscribers: Dict[
self.read_value = self.read_decoded_value Callable, Callable
self.write_value = self.write_decoded_value ] = {} # Map from subscriber to proxy subscriber
else:
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value self.write_value = self.write_encoded_value
else:
if hasattr(self.wrapped_characteristic, 'subscribe'): self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name): def __getattr__(self, name):
@@ -590,11 +589,13 @@ class CharacteristicAdapter:
else: else:
setattr(self.wrapped_characteristic, name, value) setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection): async def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection)) return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def write_encoded_value(self, connection, value): async def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value( return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value) connection, self.decode_value(value)
) )
@@ -729,13 +730,24 @@ class Descriptor(Attribute):
''' '''
def __str__(self) -> str: def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return ( return (
f'Descriptor(handle=0x{self.handle:04X}, ' f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, ' f'type={self.type}, '
f'value={self.read_value(None).hex()})' f'value={value_str})'
) )
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag): class ClientCharacteristicConfigurationBits(enum.IntFlag):
''' '''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit

View File

@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from bumble.colors import color
from .core import UUID from bumble.core import UUID
from .att import ( from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR, ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID, ATT_CID,
@@ -60,7 +60,7 @@ from .att import (
ATT_Write_Response, ATT_Write_Response,
Attribute, Attribute,
) )
from .gatt import ( from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE, GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,6 +74,7 @@ from .gatt import (
Descriptor, Descriptor,
Service, Service,
) )
from bumble.utils import AsyncRunner
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -379,7 +380,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
attribute.read_value(connection) await attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -422,7 +423,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
attribute.read_value(connection) await attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -650,7 +651,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_find_by_type_value_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
''' '''
@@ -658,13 +660,13 @@ class Server(EventEmitter):
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( async for attribute in (
attribute attribute
for attribute in self.attributes for attribute in self.attributes
if attribute.handle >= request.starting_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4 and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -702,7 +704,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_by_type_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
@@ -725,7 +728,7 @@ class Server(EventEmitter):
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = attribute.read_value(connection) attribute_value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
# If the first attribute is unreadable, return an error # If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point # Otherwise return attributes up to this point
@@ -767,14 +770,15 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = attribute.read_value(connection) value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -792,14 +796,15 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_blob_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = attribute.read_value(connection) value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -836,7 +841,8 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_by_group_type_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
''' '''
@@ -864,7 +870,7 @@ class Server(EventEmitter):
): ):
# No need to catch permission errors here, since these attributes # No need to catch permission errors here, since these attributes
# must all be world-readable # must all be world-readable
attribute_value = attribute.read_value(connection) attribute_value = await attribute.read_value(connection)
# Check the attribute value size # Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251) max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
@@ -903,7 +909,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_write_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
''' '''
@@ -936,12 +943,13 @@ class Server(EventEmitter):
return return
# Accept the value # Accept the value
attribute.write_value(connection, request.attribute_value) await attribute.write_value(connection, request.attribute_value)
# Done # Done
self.send_response(connection, ATT_Write_Response()) self.send_response(connection, ATT_Write_Response())
def on_att_write_command(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
''' '''
@@ -959,7 +967,7 @@ class Server(EventEmitter):
# Accept the value # Accept the value
try: try:
attribute.write_value(connection, request.attribute_value) await attribute.write_value(connection, request.attribute_value)
except Exception as error: except Exception as error:
logger.exception(f'!!! ignoring exception: {error}') logger.exception(f'!!! ignoring exception: {error}')

View File

@@ -18,7 +18,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
import logging import logging
from typing import List from typing import List, Optional
from bumble import l2cap from bumble import l2cap
from ..core import AdvertisingData from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0]) self.emit('volume', connection, value[0])
# Handler for audio control commands # Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value): def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}') logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0] opcode = value[0]
if opcode == AshaService.OPCODE_START: if opcode == AshaService.OPCODE_START:

View File

@@ -114,7 +114,7 @@ class SamplingFrequency(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency''' '''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off # fmt: off
FREQ_8000 = 0x01 FREQ_8000 = 0x01
FREQ_11025 = 0x02 FREQ_11025 = 0x02
FREQ_16000 = 0x03 FREQ_16000 = 0x03
FREQ_22050 = 0x04 FREQ_22050 = 0x04
@@ -430,7 +430,7 @@ class AseResponseCode(enum.IntEnum):
REJECTED_METADATA = 0x0B REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum): class AseReasonCode(enum.IntEnum):
@@ -1066,7 +1066,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter. # Readonly. Do nothing in the setter.
pass pass
def on_read(self, _: device.Connection) -> bytes: def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value return self.value
def __str__(self) -> str: def __str__(self) -> str:

View File

@@ -152,7 +152,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics) super().__init__(characteristics)
def on_sirk_read(self, _connection: device.Connection) -> bytes: def on_sirk_read(self, _connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT: if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
else: else:

View File

@@ -280,17 +280,14 @@ class AsyncRunner:
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs) coroutine = func(*args, **kwargs)
if queue is None: if queue is None:
# Create a task to run the coroutine # Spawn the coroutine as a task
async def run(): async def run():
try: try:
await coroutine await coroutine
except Exception: except Exception:
logger.warning( logger.exception(color("!!! Exception in wrapper:", "red"))
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
asyncio.create_task(run()) AsyncRunner.spawn(run())
else: else:
# Queue the coroutine to be awaited by the work queue # Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine) queue.enqueue(coroutine)

View File

@@ -48,7 +48,8 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy, PublishedAudioCapabilitiesServiceProxy,
) )
from .test_utils import TwoDevices from tests.test_utils import TwoDevices
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -20,11 +20,10 @@ import logging
import os import os
import struct import struct
import pytest import pytest
from unittest.mock import Mock, ANY from unittest.mock import AsyncMock, Mock, ANY
from bumble.controller import Controller from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_server import Server
from bumble.link import LocalLink from bumble.link import LocalLink
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.host import Host from bumble.host import Host
@@ -120,9 +119,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE, Characteristic.READABLE,
123, 123,
) )
x = c.read_value(None) x = await c.read_value(None)
assert x == bytes([123]) assert x == bytes([123])
c.write_value(None, bytes([122])) await c.write_value(None, bytes([122]))
assert c.value == 122 assert c.value == 122
class FooProxy(CharacteristicProxy): class FooProxy(CharacteristicProxy):
@@ -152,7 +151,22 @@ async def test_characteristic_encoding():
bytes([123]), bytes([123]),
) )
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic]) async def async_read(connection):
return 0x05060708
async_characteristic = PackedCharacteristicAdapter(
Characteristic(
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
Characteristic.Properties.READ,
Characteristic.READABLE,
CharacteristicValue(read=async_read),
),
'>I',
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
)
server.add_service(service) server.add_service(service)
await client.power_on() await client.power_on()
@@ -184,6 +198,13 @@ async def test_characteristic_encoding():
await async_barrier() await async_barrier()
assert characteristic.value == bytes([50]) assert characteristic.value == bytes([50])
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
assert len(c2) == 1
c2 = c2[0]
cd2 = PackedCharacteristicAdapter(c2, ">I")
cd2v = await cd2.read_value()
assert cd2v == 0x05060708
last_change = None last_change = None
def on_change(value): def on_change(value):
@@ -285,7 +306,8 @@ async def test_attribute_getters():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_CharacteristicAdapter(): @pytest.mark.asyncio
async def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent # Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3]) v = bytes([1, 2, 3])
c = Characteristic( c = Characteristic(
@@ -296,11 +318,11 @@ def test_CharacteristicAdapter():
) )
a = CharacteristicAdapter(c) a = CharacteristicAdapter(c)
value = a.read_value(None) value = await a.read_value(None)
assert value == v assert value == v
v = bytes([3, 4, 5]) v = bytes([3, 4, 5])
a.write_value(None, v) await a.write_value(None, v)
assert c.value == v assert c.value == v
# Simple delegated adapter # Simple delegated adapter
@@ -308,11 +330,11 @@ def test_CharacteristicAdapter():
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
) )
value = a.read_value(None) value = await a.read_value(None)
assert value == bytes(reversed(v)) assert value == bytes(reversed(v))
v = bytes([3, 4, 5]) v = bytes([3, 4, 5])
a.write_value(None, v) await a.write_value(None, v)
assert a.value == bytes(reversed(v)) assert a.value == bytes(reversed(v))
# Packed adapter with single element format # Packed adapter with single element format
@@ -321,10 +343,10 @@ def test_CharacteristicAdapter():
c.value = v c.value = v
a = PackedCharacteristicAdapter(c, '>H') a = PackedCharacteristicAdapter(c, '>H')
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == v assert a.value == v
# Packed adapter with multi-element format # Packed adapter with multi-element format
@@ -334,10 +356,10 @@ def test_CharacteristicAdapter():
c.value = (v1, v2) c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH') a = PackedCharacteristicAdapter(c, '>HH')
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == (v1, v2) assert a.value == (v1, v2)
# Mapped adapter # Mapped adapter
@@ -348,10 +370,10 @@ def test_CharacteristicAdapter():
c.value = mapped c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == mapped assert a.value == mapped
# UTF-8 adapter # UTF-8 adapter
@@ -360,27 +382,49 @@ def test_CharacteristicAdapter():
c.value = v c.value = v
a = UTF8CharacteristicAdapter(c) a = UTF8CharacteristicAdapter(c)
value = a.read_value(None) value = await a.read_value(None)
assert value == ev assert value == ev
c.value = None c.value = None
a.write_value(None, ev) await a.write_value(None, ev)
assert a.value == v assert a.value == v
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_CharacteristicValue(): @pytest.mark.asyncio
async def test_CharacteristicValue():
b = bytes([1, 2, 3]) b = bytes([1, 2, 3])
c = CharacteristicValue(read=lambda _: b)
x = c.read(None) async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b assert x == b
result = [] m = Mock()
c = CharacteristicValue( c = CharacteristicValue(write=m)
write=lambda connection, value: result.append((connection, value))
)
z = object() z = object()
c.write(z, b) c.write(z, b)
assert result == [(z, b)] m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue_async():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
m = AsyncMock()
c = CharacteristicValue(write=m)
z = object()
await c.write(z, b)
m.assert_called_once_with(z, b)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -961,12 +1005,18 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
await test_read_write() await test_read_write()
await test_read_write2() await test_read_write2()
await test_subscribe_notify() await test_subscribe_notify()
await test_unsubscribe() await test_unsubscribe()
await test_characteristic_encoding() await test_characteristic_encoding()
await test_mtu_exchange() await test_mtu_exchange()
await test_CharacteristicValue()
await test_CharacteristicValue_async()
await test_CharacteristicAdapter()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1105,9 +1155,4 @@ def test_get_attribute_group():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
test_CharacteristicValue()
test_CharacteristicAdapter()
asyncio.run(async_main()) asyncio.run(async_main())