Typing att

This commit is contained in:
Josh Wu
2023-09-08 23:28:25 +08:00
parent 67418e649a
commit e559744f32
5 changed files with 243 additions and 136 deletions

View File

@@ -23,13 +23,14 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum
import functools import functools
import struct import struct
from pyee import EventEmitter from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value, HCI_Constant from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color from bumble.colors import color
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long # pylint: enable=line-too-long
# pylint: disable=invalid-name # pylint: disable=invalid-name
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -209,7 +211,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {} pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0 op_code = 0
name = None name: str
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu):
@@ -719,48 +721,68 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
''' '''
# -----------------------------------------------------------------------------
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def write(self, connection, value: bytes) -> None:
...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(EventEmitter): class Attribute(EventEmitter):
# Permission flags class Permissions(enum.IntFlag):
READABLE = 0x01 READABLE = 0x01
WRITEABLE = 0x02 WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04 READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08 WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10 READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20 WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40 READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80 WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = { @classmethod
READABLE: 'READABLE', def from_string(cls, permissions_str: str) -> Attribute.Permissions:
WRITEABLE: 'WRITEABLE', try:
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION', return functools.reduce(
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION', lambda x, y: x | Attribute.Permissions[y],
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION', permissions_str.replace('|', ',').split(","),
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION', Attribute.Permissions(0),
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION', )
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION', except TypeError as exc:
} # The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc
@staticmethod # Permission flags(legacy-use only)
def string_to_permissions(permissions_str: str): READABLE = Permissions.READABLE
try: WRITEABLE = Permissions.WRITEABLE
return functools.reduce( READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y), WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
permissions_str.split(","), READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
0, WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
) READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
except TypeError as exc: WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
) from exc
def __init__(self, attribute_type, permissions, value=b''): value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
self.end_group_handle = 0 self.end_group_handle = 0
if isinstance(permissions, str): if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions) self.permissions = Attribute.Permissions.from_string(permissions)
else: else:
self.permissions = permissions self.permissions = permissions
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
else: else:
self.value = value self.value = value
def encode_value(self, value): def encode_value(self, value: Any) -> bytes:
return value return value
def decode_value(self, value_bytes): def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def read_value(self, connection: Connection): def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
self.permissions & self.READ_REQUIRES_ENCRYPTION (self.permissions & self.READ_REQUIRES_ENCRYPTION)
) and not connection.encryption: and connection is not None
and not connection.encryption
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
) )
if ( if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION (self.permissions & self.READ_REQUIRES_AUTHENTICATION)
) and not connection.authenticated: and connection is not None
and not connection.authenticated
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
) )
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
) )
if read := getattr(self.value, 'read', None): if hasattr(self.value, 'read'):
try: try:
value = read(connection) # pylint: disable=not-callable value = self.value.read(connection)
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value) return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes): def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption: ) and not connection.encryption:
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes) value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None): if hasattr(self.value, 'write'):
try: try:
write(connection, value) # pylint: disable=not-callable self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle

View File

@@ -28,7 +28,7 @@ import enum
import functools import functools
import logging import logging
import struct import struct
from typing import Optional, Sequence, List from typing import Optional, Sequence, Iterable, List, Union
from .colors import color from .colors import color
from .core import UUID, get_dict_key_by_value from .core import UUID, get_dict_key_by_value
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services): def show_services(services: Iterable[Service]) -> None:
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -210,11 +210,11 @@ class Service(Attribute):
def __init__( def __init__(
self, self,
uuid, uuid: Union[str, UUID],
characteristics: List[Characteristic], characteristics: List[Characteristic],
primary=True, primary=True,
included_services: List[Service] = [], included_services: List[Service] = [],
): ) -> None:
# Convert the uuid to a UUID object if it isn't already # Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str): if isinstance(uuid, str):
uuid = UUID(uuid) uuid = UUID(uuid)
@@ -239,7 +239,7 @@ class Service(Attribute):
""" """
return None return None
def __str__(self): def __str__(self) -> str:
return ( return (
f'Service(handle=0x{self.handle:04X}, ' f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, ' f'end=0x{self.end_group_handle:04X}, '
@@ -255,9 +255,11 @@ class TemplateService(Service):
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID: Optional[UUID] = None UUID: UUID
def __init__(self, characteristics, primary=True): def __init__(
self, characteristics: List[Characteristic], primary: bool = True
) -> None:
super().__init__(self.UUID, characteristics, primary) super().__init__(self.UUID, characteristics, primary)
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service service: Service
def __init__(self, service): def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack( declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() '<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
) )
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
) )
self.service = service self.service = service
def __str__(self): def __str__(self) -> str:
return ( return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, ' f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, ' f'group_starting_handle=0x{self.service.handle:04X}, '
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}" f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
) )
def __str__(self): def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python # NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11. # versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join( return '|'.join(
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
def __init__( def __init__(
self, self,
uuid, uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties, properties: Characteristic.Properties,
permissions, permissions: Union[str, Attribute.Permissions],
value=b'', value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (), descriptors: Sequence[Descriptor] = (),
): ):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool: def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties return self.properties & properties == properties
def __str__(self): def __str__(self) -> str:
return ( return (
f'Characteristic(handle=0x{self.handle:04X}, ' f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, ' f'end=0x{self.end_group_handle:04X}, '
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic characteristic: Characteristic
def __init__(self, characteristic, value_handle): def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = ( declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle) struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes() + characteristic.uuid.to_pdu_bytes()
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle self.value_handle = value_handle
self.characteristic = characteristic self.characteristic = characteristic
def __str__(self): def __str__(self) -> str:
return ( return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, ' f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, ' f'value_handle=0x{self.value_handle:04X}, '
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber) return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self): def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic) wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})' return f'{self.__class__.__name__}({wrapped})'
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding Adapter that converts strings to/from bytes using UTF-8 encoding
''' '''
def encode_value(self, value): def encode_value(self, value: str) -> bytes:
return value.encode('utf-8') return value.encode('utf-8')
def decode_value(self, value): def decode_value(self, value: bytes) -> str:
return value.decode('utf-8') return value.decode('utf-8')
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
''' '''
def __str__(self): def __str__(self) -> str:
return ( return (
f'Descriptor(handle=0x{self.handle:04X}, ' f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, ' f'type={self.type}, '

View File

@@ -28,7 +28,18 @@ import asyncio
import logging import logging
import struct import struct
from datetime import datetime from datetime import datetime
from typing import List, Optional, Dict, Tuple, Callable, Union, Any from typing import (
List,
Optional,
Dict,
Tuple,
Callable,
Union,
Any,
Iterable,
Type,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE, GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
ClientCharacteristicConfigurationBits, ClientCharacteristicConfigurationBits,
TemplateService,
) )
if TYPE_CHECKING:
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies # Proxies
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter): class AttributeProxy(EventEmitter):
client: Client def __init__(
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
def __init__(self, client, handle, end_group_handle, attribute_type): ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.client = client self.client = client
self.handle = handle self.handle = handle
self.end_group_handle = end_group_handle self.end_group_handle = end_group_handle
self.type = attribute_type self.type = attribute_type
async def read_value(self, no_long_read=False): async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value( return self.decode_value(
await self.client.read_value(self.handle, no_long_read) await self.client.read_value(self.handle, no_long_read)
) )
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response self.handle, self.encode_value(value), with_response
) )
def encode_value(self, value): def encode_value(self, value: Any) -> bytes:
return value return value
def decode_value(self, value_bytes): def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def __str__(self): def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid): def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self) return self.client.get_characteristics_by_uuid(uuid, self)
def __str__(self): def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy): class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties properties: Characteristic.Properties
descriptors: List[DescriptorProxy] descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable] subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__( def __init__(
self, self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self) return await self.client.discover_descriptors(self)
async def subscribe( async def subscribe(
self, subscriber: Optional[Callable] = None, prefer_notify=True self,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
): ):
if subscriber is not None: if subscriber is not None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber) return await self.client.unsubscribe(self, subscriber)
def __str__(self): def __str__(self) -> str:
return ( return (
f'Characteristic(handle=0x{self.handle:04X}, ' f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, ' f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type): def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type) super().__init__(client, handle, 0, descriptor_type)
def __str__(self): def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})' return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies Base class for profile-specific service proxies
''' '''
SERVICE_CLASS: Type[TemplateService]
@classmethod @classmethod
def from_client(cls, client): def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client: class Client:
services: List[ServiceProxy] services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
def __init__(self, connection): def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
self.mtu_exchange_done = False self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
def send_gatt_pdu(self, pdu): def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command): async def send_command(self, command: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
) )
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request): async def send_request(self, request: ATT_PDU):
logger.debug( logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
) )
@@ -279,14 +302,14 @@ class Client:
return response return response
def send_confirmation(self, confirmation): def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug( logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}' f'{confirmation}'
) )
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu): async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
if mtu < ATT_DEFAULT_MTU: if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu return self.connection.att_mtu
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service=None): def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> List[CharacteristicProxy]:
services = [service] if service else self.services services = [service] if service else self.services
return [ return [
c c
@@ -363,7 +388,7 @@ class Client:
if not already_known: if not already_known:
self.services.append(service) self.services.append(service)
async def discover_services(self, uuids=None) -> List[ServiceProxy]: async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.1 Discover All Primary Services See Vol 3, Part G - 4.4.1 Discover All Primary Services
''' '''
@@ -435,7 +460,7 @@ class Client:
return services return services
async def discover_service(self, uuid): async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
''' '''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}' f'{HCI_Constant.error_name(response.error_code)}'
) )
# TODO raise appropriate exception # TODO raise appropriate exception
return return []
break break
for attribute_handle, end_group_handle in response.handles_information: for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning( logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}' f'bogus handle values: {attribute_handle} {end_group_handle}'
) )
return return []
# Create a service proxy for this service # Create a service proxy for this service
service = ServiceProxy( service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors return descriptors
async def discover_attributes(self): async def discover_attributes(self) -> List[AttributeProxy]:
''' '''
Discover all attributes, regardless of type Discover all attributes, regardless of type
''' '''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left # No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True) await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False): async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes # Return the value as bytes
return attribute_value return attribute_value
async def read_characteristics_by_uuid(self, uuid, service): async def read_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy]
) -> List[bytes]:
''' '''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
''' '''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values return characteristics_values
async def write_value(self, attribute, value, with_response=False): async def write_value(
self,
attribute: Union[int, AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
''' '''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value Value
@@ -990,7 +1024,7 @@ class Client:
) )
) )
def on_gatt_pdu(self, att_pdu): def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
) )
@@ -1013,6 +1047,7 @@ class Client:
return return
# Return the response to the coroutine that is waiting for it # Return the response to the coroutine that is waiting for it
assert self.pending_response is not None
self.pending_response.set_result(att_pdu) self.pending_response.set_result(att_pdu)
else: else:
handler_name = f'on_{att_pdu.name.lower()}' handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication # Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation()) self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes): def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = ( self.cached_values[attribute_handle] = (
datetime.now(), datetime.now(),
value, value,

View File

@@ -23,11 +23,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
import struct import struct
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR, ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR, ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS, ATT_REQUESTS,
ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR, ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error, ATT_Error,
@@ -73,6 +75,8 @@ from .gatt import (
Service, Service,
) )
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(EventEmitter):
attributes: List[Attribute] attributes: List[Attribute]
services: List[Service]
attributes_by_handle: Dict[int, Attribute]
subscribers: Dict[int, Dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
def __init__(self, device): def __init__(self, device: Device) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.services = [] self.services = []
@@ -107,16 +116,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None) self.pending_confirmations = defaultdict(lambda: None)
def __str__(self): def __str__(self) -> str:
return "\n".join(map(str, self.attributes)) return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu): def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self): def next_handle(self) -> int:
return 1 + len(self.attributes) return 1 + len(self.attributes)
def get_advertising_service_data(self): def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return { return {
attribute: data attribute: data
for attribute in self.attributes for attribute in self.attributes
@@ -124,7 +133,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data()) and (data := attribute.get_advertising_data())
} }
def get_attribute(self, handle): def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle) attribute = self.attributes_by_handle.get(handle)
if attribute: if attribute:
return attribute return attribute
@@ -173,12 +182,17 @@ class Server(EventEmitter):
return next( return next(
( (
(attribute, self.get_attribute(attribute.characteristic.handle)) (
attribute,
self.get_attribute(attribute.characteristic.handle),
) # type: ignore
for attribute in map( for attribute in map(
self.get_attribute, self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1), range(service_handle.handle, service_handle.end_group_handle + 1),
) )
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE if attribute is not None
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid and attribute.characteristic.uuid == characteristic_uuid
), ),
None, None,
@@ -197,7 +211,7 @@ class Server(EventEmitter):
return next( return next(
( (
attribute attribute # type: ignore
for attribute in map( for attribute in map(
self.get_attribute, self.get_attribute,
range( range(
@@ -205,12 +219,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1, characteristic_value.end_group_handle + 1,
), ),
) )
if attribute.type == descriptor_uuid if attribute is not None and attribute.type == descriptor_uuid
), ),
None, None,
) )
def add_attribute(self, attribute): def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute # Assign a handle to this attribute
attribute.handle = self.next_handle() attribute.handle = self.next_handle()
attribute.end_group_handle = ( attribute.end_group_handle = (
@@ -220,7 +234,7 @@ class Server(EventEmitter):
# Add this attribute to the list # Add this attribute to the list
self.attributes.append(attribute) self.attributes.append(attribute)
def add_service(self, service: Service): def add_service(self, service: Service) -> None:
# Add the service attribute to the DB # Add the service attribute to the DB
self.add_attribute(service) self.add_attribute(service)
@@ -285,11 +299,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle service.end_group_handle = self.attributes[-1].handle
self.services.append(service) self.services.append(service)
def add_services(self, services): def add_services(self, services: Iterable[Service]) -> None:
for service in services: for service in services:
self.add_service(service) self.add_service(service)
def read_cccd(self, connection, characteristic): def read_cccd(
self, connection: Optional[Connection], characteristic: Characteristic
) -> bytes:
if connection is None: if connection is None:
return bytes([0, 0]) return bytes([0, 0])
@@ -300,7 +316,12 @@ class Server(EventEmitter):
return cccd or bytes([0, 0]) return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value): def write_cccd(
self,
connection: Connection,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug( logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, ' f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}' f'handle=0x{characteristic.handle:04X}: {value.hex()}'
@@ -327,13 +348,19 @@ class Server(EventEmitter):
indicate_enabled, indicate_enabled,
) )
def send_response(self, connection, response): def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}' f'GATT Response from server: [0x{connection.handle:04X}] {response}'
) )
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False): async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -370,7 +397,13 @@ class Server(EventEmitter):
) )
self.send_gatt_pdu(connection.handle, bytes(notification)) self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False): async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -411,15 +444,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_confirmations[ pending_confirmation = self.pending_confirmations[
connection.handle connection.handle
] = asyncio.get_running_loop().create_future() ] = asyncio.get_running_loop().create_future()
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for( await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +458,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers( async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False self,
): indicate: bool,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription # Get all the connections for which there's at least one subscription
connections = [ connections = [
connection connection
@@ -450,13 +485,23 @@ class Server(EventEmitter):
] ]
) )
async def notify_subscribers(self, attribute, value=None, force=False): async def notify_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(False, attribute, value, force) return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False): async def indicate_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(True, attribute, value, force) return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection): def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers: if connection.handle in self.subscribers:
del self.subscribers[connection.handle] del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores: if connection.handle in self.indication_semaphores:
@@ -464,7 +509,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations: if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle] del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, connection, att_pdu): def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}' handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
@@ -506,7 +551,7 @@ class Server(EventEmitter):
####################################################### #######################################################
# ATT handlers # ATT handlers
####################################################### #######################################################
def on_att_request(self, connection, pdu): def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
''' '''
Handler for requests without a more specific handler Handler for requests without a more specific handler
''' '''
@@ -679,7 +724,6 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:

View File

@@ -891,10 +891,10 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_attribute_string_to_permissions(): def test_permissions_from_string():
assert Attribute.string_to_permissions('READABLE') == 1 assert Attribute.Permissions.from_string('READABLE') == 1
assert Attribute.string_to_permissions('WRITEABLE') == 2 assert Attribute.Permissions.from_string('WRITEABLE') == 2
assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3 assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------