diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 0c69b12..d5a8ec7 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -207,11 +207,11 @@ class CharacteristicProxy(AttributeProxy): return await self.client.subscribe(self, subscriber, prefer_notify) - async def unsubscribe(self, subscriber=None): + async def unsubscribe(self, subscriber=None, force=False): if subscriber in self.subscribers: subscriber = self.subscribers.pop(subscriber) - return await self.client.unsubscribe(self, subscriber) + return await self.client.unsubscribe(self, subscriber, force) def __str__(self) -> str: return ( @@ -262,10 +262,8 @@ class Client: self.request_semaphore = asyncio.Semaphore(1) self.pending_request = None self.pending_response = None - self.notification_subscribers = ( - {} - ) # Notification subscribers, by attribute handle - self.indication_subscribers = {} # Indication subscribers, by attribute handle + self.notification_subscribers = {} # Subscriber set, by attribute handle + self.indication_subscribers = {} # Subscriber set, by attribute handle self.services = [] self.cached_values = {} @@ -836,6 +834,7 @@ class Client: subscriber_set = subscribers.setdefault(characteristic.handle, set()) if subscriber is not None: subscriber_set.add(subscriber) + # Add the characteristic as a subscriber, which will result in the # characteristic emitting an 'update' event when a notification or indication # is received @@ -847,7 +846,14 @@ class Client: self, characteristic: CharacteristicProxy, subscriber: Optional[Callable[[bytes], Any]] = None, + force: bool = False, ) -> None: + ''' + Unsubscribe from a characteristic. + + If `force` is True, this will write zeros to the CCCD when there are no + subscribers left, even if there were already no registered subscribers. + ''' # If we haven't already discovered the descriptors for this characteristic, # do it now if not characteristic.descriptors_discovered: @@ -861,25 +867,39 @@ class Client: logger.warning('unsubscribing from characteristic with no CCCD descriptor') return + # Check if the characteristic has subscribers + if not ( + characteristic.handle in self.notification_subscribers + or characteristic.handle in self.indication_subscribers + ): + if not force: + return + + # Remove the subscriber(s) if subscriber is not None: # Remove matching subscriber from subscriber sets for subscriber_set in ( self.notification_subscribers, self.indication_subscribers, ): - subscribers = subscriber_set.get(characteristic.handle, set()) - if subscriber in subscribers: + if ( + subscribers := subscriber_set.get(characteristic.handle) + ) and subscriber in subscribers: subscribers.remove(subscriber) # Cleanup if we removed the last one if not subscribers: del subscriber_set[characteristic.handle] else: - # Remove all subscribers for this attribute from the sets! + # Remove all subscribers for this attribute from the sets self.notification_subscribers.pop(characteristic.handle, None) self.indication_subscribers.pop(characteristic.handle, None) - if not self.notification_subscribers and not self.indication_subscribers: + # Update the CCCD + if not ( + characteristic.handle in self.notification_subscribers + or characteristic.handle in self.indication_subscribers + ): # No more subscribers left await self.write_value(cccd, b'\x00\x00', with_response=True) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index d9f6d60..85b40a9 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -20,6 +20,7 @@ import logging import os import struct import pytest +from unittest.mock import Mock, ANY from bumble.controller import Controller from bumble.gatt_client import CharacteristicProxy @@ -763,6 +764,83 @@ async def test_subscribe_notify(): assert not c3._called_3 +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_unsubscribe(): + [client, server] = LinkedDevices().devices[:2] + + characteristic1 = Characteristic( + 'FDB159DB-036C-49E3-B3DB-6325AC750806', + Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, + Characteristic.READABLE, + bytes([1, 2, 3]), + ) + characteristic2 = Characteristic( + '3234C4F4-3F34-4616-8935-45A50EE05DEB', + Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, + Characteristic.READABLE, + bytes([1, 2, 3]), + ) + + service1 = Service( + '3A657F47-D34F-46B3-B1EC-698E29B6B829', + [characteristic1, characteristic2], + ) + server.add_services([service1]) + + mock1 = Mock() + characteristic1.on('subscription', mock1) + mock2 = Mock() + characteristic2.on('subscription', mock2) + + await client.power_on() + await server.power_on() + connection = await client.connect(server.random_address) + peer = Peer(connection) + + await peer.discover_services() + await peer.discover_characteristics() + c = peer.get_characteristics_by_uuid(characteristic1.uuid) + assert len(c) == 1 + c1 = c[0] + c = peer.get_characteristics_by_uuid(characteristic2.uuid) + assert len(c) == 1 + c2 = c[0] + + await c1.subscribe() + await async_barrier() + mock1.assert_called_once_with(ANY, True, False) + + await c2.subscribe() + await async_barrier() + mock2.assert_called_once_with(ANY, True, False) + + mock1.reset_mock() + await c1.unsubscribe() + await async_barrier() + mock1.assert_called_once_with(ANY, False, False) + + mock2.reset_mock() + await c2.unsubscribe() + await async_barrier() + mock2.assert_called_once_with(ANY, False, False) + + mock1.reset_mock() + await c1.unsubscribe() + await async_barrier() + mock1.assert_not_called() + + mock2.reset_mock() + await c2.unsubscribe() + await async_barrier() + mock2.assert_not_called() + + mock1.reset_mock() + await c1.unsubscribe(force=True) + await async_barrier() + mock1.assert_called_once_with(ANY, False, False) + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_mtu_exchange(): @@ -886,6 +964,7 @@ async def async_main(): await test_read_write() await test_read_write2() await test_subscribe_notify() + await test_unsubscribe() await test_characteristic_encoding() await test_mtu_exchange()