Merge pull request #348 from zxzxwu/gattc

Typing GATT Client and Device Peer
This commit is contained in:
zxzxwu
2023-11-29 15:09:40 +08:00
committed by GitHub
2 changed files with 91 additions and 34 deletions

View File

@@ -23,6 +23,7 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Iterable
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@@ -32,6 +33,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Type, Type,
TypeVar,
Set, Set,
Union, Union,
cast, cast,
@@ -440,8 +442,11 @@ class LePhyOptions:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_PROXY_CLASS = TypeVar('_PROXY_CLASS', bound=gatt_client.ProfileServiceProxy)
class Peer: class Peer:
def __init__(self, connection): def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
# Create a GATT client for the connection # Create a GATT client for the connection
@@ -449,77 +454,113 @@ class Peer:
connection.gatt_client = self.gatt_client connection.gatt_client = self.gatt_client
@property @property
def services(self): def services(self) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.services return self.gatt_client.services
async def request_mtu(self, mtu): async def request_mtu(self, mtu: int) -> int:
mtu = await self.gatt_client.request_mtu(mtu) mtu = await self.gatt_client.request_mtu(mtu)
self.connection.emit('connection_att_mtu_update') self.connection.emit('connection_att_mtu_update')
return mtu return mtu
async def discover_service(self, uuid): async def discover_service(
self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid) return await self.gatt_client.discover_service(uuid)
async def discover_services(self, uuids=()): async def discover_services(
self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids) return await self.gatt_client.discover_services(uuids)
async def discover_included_services(self, service): async def discover_included_services(
self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service) return await self.gatt_client.discover_included_services(service)
async def discover_characteristics(self, uuids=(), service=None): async def discover_characteristics(
self,
uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics( return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service uuids=uuids, service=service
) )
async def discover_descriptors( async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None self,
characteristic: Optional[gatt_client.CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
): ):
return await self.gatt_client.discover_descriptors( return await self.gatt_client.discover_descriptors(
characteristic, start_handle, end_handle characteristic, start_handle, end_handle
) )
async def discover_attributes(self): async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes() return await self.gatt_client.discover_attributes()
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): async def subscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
return await self.gatt_client.subscribe( return await self.gatt_client.subscribe(
characteristic, subscriber, prefer_notify characteristic, subscriber, prefer_notify
) )
async def unsubscribe(self, characteristic, subscriber=None): async def unsubscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
return await self.gatt_client.unsubscribe(characteristic, subscriber) return await self.gatt_client.unsubscribe(characteristic, subscriber)
async def read_value(self, attribute): async def read_value(
self, attribute: Union[int, gatt_client.AttributeProxy]
) -> bytes:
return await self.gatt_client.read_value(attribute) return await self.gatt_client.read_value(attribute)
async def write_value(self, attribute, value, with_response=False): async def write_value(
self,
attribute: Union[int, gatt_client.AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
return await self.gatt_client.write_value(attribute, value, with_response) return await self.gatt_client.write_value(attribute, value, with_response)
async def read_characteristics_by_uuid(self, uuid, service=None): async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service) return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid) return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid(self, uuid, service=None): def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[gatt_client.CharacteristicProxy]:
return self.gatt_client.get_characteristics_by_uuid(uuid, service) return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(self, proxy_class): def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return proxy_class.from_client(self.gatt_client) return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
async def discover_service_and_create_proxy(self, proxy_class): async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
# Discover the first matching service and its characteristics # Discover the first matching service and its characteristics
services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID) services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID)
if services: if services:
service = services[0] service = services[0]
await service.discover_characteristics() await service.discover_characteristics()
return self.create_service_proxy(proxy_class) return self.create_service_proxy(proxy_class)
return None
async def sustain(self, timeout=None): async def sustain(self, timeout: Optional[float] = None) -> None:
await self.connection.sustain(timeout) await self.connection.sustain(timeout)
# [Classic only] # [Classic only]
async def request_name(self): async def request_name(self) -> str:
return await self.connection.request_remote_name() return await self.connection.request_remote_name()
async def __aenter__(self): async def __aenter__(self):
@@ -532,7 +573,7 @@ class Peer:
async def __aexit__(self, exc_type, exc_value, traceback): async def __aexit__(self, exc_type, exc_value, traceback):
pass pass
def __str__(self): def __str__(self) -> str:
return f'{self.connection.peer_address} as {self.connection.role_name}' return f'{self.connection.peer_address} as {self.connection.role_name}'
@@ -732,7 +773,7 @@ class Connection(CompositeEventEmitter):
async def switch_role(self, role: int) -> None: async def switch_role(self, role: int) -> None:
return await self.device.switch_role(self, role) return await self.device.switch_role(self, role)
async def sustain(self, timeout=None): async def sustain(self, timeout: Optional[float] = None) -> None:
"""Idles the current task waiting for a disconnect or timeout""" """Idles the current task waiting for a disconnect or timeout"""
abort = asyncio.get_running_loop().create_future() abort = asyncio.get_running_loop().create_future()

View File

@@ -38,6 +38,7 @@ from typing import (
Any, Any,
Iterable, Iterable,
Type, Type,
Set,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -128,7 +129,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy] included_services: List[ServiceProxy]
@staticmethod @staticmethod
def from_client(service_class, client, service_uuid): def from_client(service_class, client: Client, service_uuid: UUID):
# The service and its characteristics are considered to have already been # The service and its characteristics are considered to have already been
# discovered # discovered
services = client.get_services_by_uuid(service_uuid) services = client.get_services_by_uuid(service_uuid)
@@ -246,8 +247,12 @@ class ProfileServiceProxy:
class Client: class Client:
services: List[ServiceProxy] services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]] notification_subscribers: Dict[
indication_subscribers: Dict[int, Callable[[bytes], Any]] int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]] pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU] pending_request: Optional[ATT_PDU]
@@ -682,8 +687,8 @@ class Client:
async def discover_descriptors( async def discover_descriptors(
self, self,
characteristic: Optional[CharacteristicProxy] = None, characteristic: Optional[CharacteristicProxy] = None,
start_handle=None, start_handle: Optional[int] = None,
end_handle=None, end_handle: Optional[int] = None,
) -> List[DescriptorProxy]: ) -> List[DescriptorProxy]:
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -789,7 +794,12 @@ class Client:
return attributes return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -833,7 +843,11 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True) await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None): async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -853,7 +867,7 @@ class Client:
self.notification_subscribers, self.notification_subscribers,
self.indication_subscribers, self.indication_subscribers,
): ):
subscribers = subscriber_set.get(characteristic.handle, []) subscribers = subscriber_set.get(characteristic.handle, set())
if subscriber in subscribers: if subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
@@ -871,7 +885,7 @@ class Client:
async def read_value( async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any: ) -> bytes:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -1067,7 +1081,7 @@ class Client:
def on_att_handle_value_notification(self, notification): def on_att_handle_value_notification(self, notification):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get( subscribers = self.notification_subscribers.get(
notification.attribute_handle, [] notification.attribute_handle, set()
) )
if not subscribers: if not subscribers:
logger.warning('!!! received notification with no subscriber') logger.warning('!!! received notification with no subscriber')
@@ -1081,7 +1095,9 @@ class Client:
def on_att_handle_value_indication(self, indication): def on_att_handle_value_indication(self, indication):
# Call all subscribers # Call all subscribers
subscribers = self.indication_subscribers.get(indication.attribute_handle, []) subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
if not subscribers: if not subscribers:
logger.warning('!!! received indication with no subscriber') logger.warning('!!! received indication with no subscriber')