This commit is contained in:
Gilles Boccon-Gibod
2023-11-29 19:19:40 -08:00
parent 24524d88cb
commit 58c9c4f590
2 changed files with 109 additions and 10 deletions

View File

@@ -207,11 +207,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None):
async def unsubscribe(self, subscriber=None, force=False):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
return await self.client.unsubscribe(self, subscriber, force)
def __str__(self) -> str:
return (
@@ -262,10 +262,8 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.notification_subscribers = {} # Subscriber set, by attribute handle
self.indication_subscribers = {} # Subscriber set, by attribute handle
self.services = []
self.cached_values = {}
@@ -836,6 +834,7 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
@@ -847,7 +846,14 @@ class Client:
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -861,25 +867,39 @@ class Client:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return
# Remove the subscriber(s)
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, set())
if subscriber in subscribers:
if (
subscribers := subscriber_set.get(characteristic.handle)
) and subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
# Remove all subscribers for this attribute from the sets
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# Update the CCCD
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)

View File

@@ -20,6 +20,7 @@ import logging
import os
import struct
import pytest
from unittest.mock import Mock, ANY
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
@@ -763,6 +764,83 @@ async def test_subscribe_notify():
assert not c3._called_3
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
characteristic2 = Characteristic(
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic1, characteristic2],
)
server.add_services([service1])
mock1 = Mock()
characteristic1.on('subscription', mock1)
mock2 = Mock()
characteristic2.on('subscription', mock2)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert len(c) == 1
c2 = c[0]
await c1.subscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, True, False)
await c2.subscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, True, False)
mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)
mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, False, False)
mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_not_called()
mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_not_called()
mock1.reset_mock()
await c1.unsubscribe(force=True)
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
@@ -886,6 +964,7 @@ async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()