Merge pull request #71 from mogenson/prefer-notify

Add prefer_notify option to gatt_client subscribe()
This commit is contained in:
Michael Mogenson
2022-11-13 19:53:09 -05:00
committed by GitHub
3 changed files with 62 additions and 34 deletions

View File

@@ -294,8 +294,8 @@ class Peer:
async def discover_attributes(self): async def discover_attributes(self):
return await self.gatt_client.discover_attributes() return await self.gatt_client.discover_attributes()
async def subscribe(self, characteristic, subscriber=None): async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
return await self.gatt_client.subscribe(characteristic, subscriber) return await self.gatt_client.subscribe(characteristic, subscriber, prefer_notify)
async def unsubscribe(self, characteristic, subscriber=None): async def unsubscribe(self, characteristic, subscriber=None):
return await self.gatt_client.unsubscribe(characteristic, subscriber) return await self.gatt_client.unsubscribe(characteristic, subscriber)

View File

@@ -26,20 +26,17 @@
import asyncio import asyncio
import logging import logging
import struct import struct
from colors import color from colors import color
from .core import ProtocolError, TimeoutError
from .hci import *
from .att import * from .att import *
from .gatt import ( from .core import InvalidStateError, ProtocolError, TimeoutError
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, from .gatt import (GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, ClientCharacteristicConfigurationBits)
Characteristic, from .hci import *
ClientCharacteristicConfigurationBits
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -115,7 +112,7 @@ class CharacteristicProxy(AttributeProxy):
async def discover_descriptors(self): async def discover_descriptors(self):
return await self.client.discover_descriptors(self) return await self.client.discover_descriptors(self)
async def subscribe(self, subscriber=None): async def subscribe(self, subscriber=None, prefer_notify=True):
if subscriber is not None: if subscriber is not None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
# We already have a proxy subscriber # We already have a proxy subscriber
@@ -129,7 +126,7 @@ class CharacteristicProxy(AttributeProxy):
self.subscribers[subscriber] = on_change self.subscribers[subscriber] = on_change
subscriber = on_change subscriber = on_change
return await self.client.subscribe(self, subscriber) return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None): async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers: if subscriber in self.subscribers:
@@ -547,7 +544,7 @@ class Client:
return attributes return attributes
async def subscribe(self, characteristic, subscriber=None): async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic, do it now # If we haven't already discovered the descriptors for this characteristic, do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic) await self.discover_descriptors(characteristic)
@@ -558,23 +555,32 @@ class Client:
logger.warning('subscribing to characteristic with no CCCD descriptor') logger.warning('subscribing to characteristic with no CCCD descriptor')
return return
# Set the subscription bits and select the subscriber set if (
bits = ClientCharacteristicConfigurationBits.DEFAULT characteristic.properties & Characteristic.NOTIFY
subscriber_sets = [] and characteristic.properties & Characteristic.INDICATE
if characteristic.properties & Characteristic.NOTIFY: ):
bits |= ClientCharacteristicConfigurationBits.NOTIFICATION if prefer_notify:
subscriber_sets.append(self.notification_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.NOTIFICATION
if characteristic.properties & Characteristic.INDICATE: subscribers = self.notification_subscribers
bits |= ClientCharacteristicConfigurationBits.INDICATION else:
subscriber_sets.append(self.indication_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
elif characteristic.properties & Characteristic.NOTIFY:
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
subscribers = self.notification_subscribers
elif characteristic.properties & Characteristic.INDICATE:
bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
else:
raise InvalidStateError("characteristic is not notify or indicate")
# Add subscribers to the sets # Add subscribers to the sets
for subscriber_set in subscriber_sets: subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None: if subscriber is not None:
subscriber_set.add(subscriber) subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the characteristic # Add the characteristic as a subscriber, which will result in the characteristic
# emitting an 'update' event when a notification or indication is received # emitting an 'update' event when a notification or indication is received
subscriber_set.add(characteristic) subscriber_set.add(characteristic)
await self.write_value(cccd, struct.pack('<H', bits), with_response=True) await self.write_value(cccd, struct.pack('<H', bits), with_response=True)

View File

@@ -612,14 +612,25 @@ async def test_subscribe_notify():
await async_barrier() await async_barrier()
assert not c2._called assert not c2._called
c3._called = False
c3._called_2 = False
c3._called_3 = False
c3._last_update = None
c3._last_update_2 = None
c3._last_update_3 = None
def on_c3_update(value): def on_c3_update(value):
c3._called = True c3._called = True
c3._last_update = value c3._last_update = value
def on_c3_update_2(value): def on_c3_update_2(value): # for notify
c3._called_2 = True c3._called_2 = True
c3._last_update_2 = value c3._last_update_2 = value
def on_c3_update_3(value): # for indicate
c3._called_3 = True
c3._last_update_3 = value
c3.on('update', on_c3_update) c3.on('update', on_c3_update)
await peer.subscribe(c3, on_c3_update_2) await peer.subscribe(c3, on_c3_update_2)
await async_barrier() await async_barrier()
@@ -629,22 +640,33 @@ async def test_subscribe_notify():
assert c3._last_update == characteristic3.value assert c3._last_update == characteristic3.value
assert c3._called_2 assert c3._called_2
assert c3._last_update_2 == characteristic3.value assert c3._last_update_2 == characteristic3.value
assert not c3._called_3
c3._called = False
c3._called_2 = False
c3._called_3 = False
await peer.unsubscribe(c3)
await peer.subscribe(c3, on_c3_update_3, prefer_notify=False)
await async_barrier()
characteristic3.value = bytes([1, 2, 3]) characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3) await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier() await async_barrier()
assert c3._called assert c3._called
assert c3._last_update == characteristic3.value assert c3._last_update == characteristic3.value
assert c3._called_2 assert not c3._called_2
assert c3._last_update_2 == characteristic3.value assert c3._called_3
assert c3._last_update_3 == characteristic3.value
c3._called = False c3._called = False
c3._called_2 = False c3._called_2 = False
c3._called_3 = False
await peer.unsubscribe(c3) await peer.unsubscribe(c3)
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3) await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3) await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier() await async_barrier()
assert not c3._called assert not c3._called
assert not c3._called_2 assert not c3._called_2
assert not c3._called_3
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------