forked from auracaster/bumble_mirror
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d1e119f176 | |||
| 2fc7a0bf04 | |||
| 073757d5dd | |||
| 20dedbd923 | |||
| 0edd6b731f | |||
| 80569bc9f3 |
+2
-2
@@ -90,7 +90,7 @@ class SnoopPacketReader:
|
|||||||
@click.command()
|
@click.command()
|
||||||
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
|
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
|
||||||
@click.argument('filename')
|
@click.argument('filename')
|
||||||
def show(format, filename):
|
def main(format, filename):
|
||||||
input = open(filename, 'rb')
|
input = open(filename, 'rb')
|
||||||
if format == 'h4':
|
if format == 'h4':
|
||||||
packet_reader = PacketReader(input)
|
packet_reader = PacketReader(input)
|
||||||
@@ -117,4 +117,4 @@ def show(format, filename):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
show()
|
main()
|
||||||
|
|||||||
+4
-6
@@ -29,6 +29,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import click
|
||||||
import usb1
|
import usb1
|
||||||
from colors import color
|
from colors import color
|
||||||
|
|
||||||
@@ -149,6 +150,8 @@ def is_bluetooth_hci(device):
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
@click.command()
|
||||||
|
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
|
||||||
def main(verbose):
|
def main(verbose):
|
||||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||||
|
|
||||||
@@ -233,9 +236,4 @@ def main(verbose):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if len(sys.argv) == 2 and sys.argv[1] == '--verbose':
|
main()
|
||||||
verbose = True
|
|
||||||
else:
|
|
||||||
verbose = False
|
|
||||||
|
|
||||||
main(verbose)
|
|
||||||
|
|||||||
+18
-4
@@ -700,16 +700,26 @@ class Attribute(EventEmitter):
|
|||||||
else:
|
else:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
def encode_value(self, value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
def decode_value(self, value_bytes):
|
||||||
|
return value_bytes
|
||||||
|
|
||||||
def read_value(self, connection):
|
def read_value(self, connection):
|
||||||
if read := getattr(self.value, 'read', None):
|
if read := getattr(self.value, 'read', None):
|
||||||
try:
|
try:
|
||||||
return read(connection)
|
value = read(connection)
|
||||||
except ATT_Error as error:
|
except ATT_Error as error:
|
||||||
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
|
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
|
||||||
else:
|
else:
|
||||||
return self.value
|
value = self.value
|
||||||
|
|
||||||
|
return self.encode_value(value)
|
||||||
|
|
||||||
|
def write_value(self, connection, value_bytes):
|
||||||
|
value = self.decode_value(value_bytes)
|
||||||
|
|
||||||
def write_value(self, connection, value):
|
|
||||||
if write := getattr(self.value, 'write', None):
|
if write := getattr(self.value, 'write', None):
|
||||||
try:
|
try:
|
||||||
write(connection, value)
|
write(connection, value)
|
||||||
@@ -721,7 +731,11 @@ class Attribute(EventEmitter):
|
|||||||
self.emit('write', connection, value)
|
self.emit('write', connection, value)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if len(self.value) > 0:
|
if type(self.value) is bytes:
|
||||||
|
value_str = self.value.hex()
|
||||||
|
else:
|
||||||
|
value_str = str(self.value)
|
||||||
|
if value_str:
|
||||||
value_string = f', value={self.value.hex()}'
|
value_string = f', value={self.value.hex()}'
|
||||||
else:
|
else:
|
||||||
value_string = ''
|
value_string = ''
|
||||||
|
|||||||
+8
-8
@@ -1210,17 +1210,17 @@ class Device(CompositeEventEmitter):
|
|||||||
def add_services(self, services):
|
def add_services(self, services):
|
||||||
self.gatt_server.add_services(services)
|
self.gatt_server.add_services(services)
|
||||||
|
|
||||||
async def notify_subscriber(self, connection, attribute, force=False):
|
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||||
await self.gatt_server.notify_subscriber(connection, attribute, force)
|
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def notify_subscribers(self, attribute, force=False):
|
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||||
await self.gatt_server.notify_subscribers(attribute, force)
|
await self.gatt_server.notify_subscribers(attribute, value, force)
|
||||||
|
|
||||||
async def indicate_subscriber(self, connection, attribute, force=False):
|
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||||
await self.gatt_server.indicate_subscriber(connection, attribute, force)
|
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def indicate_subscribers(self, attribute):
|
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||||
await self.gatt_server.indicate_subscribers(attribute)
|
await self.gatt_server.indicate_subscribers(attribute, value, force)
|
||||||
|
|
||||||
@host_event_handler
|
@host_event_handler
|
||||||
def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters):
|
def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters):
|
||||||
|
|||||||
+32
-4
@@ -303,6 +303,7 @@ class CharacteristicAdapter:
|
|||||||
'''
|
'''
|
||||||
def __init__(self, characteristic):
|
def __init__(self, characteristic):
|
||||||
self.wrapped_characteristic = characteristic
|
self.wrapped_characteristic = characteristic
|
||||||
|
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||||
|
|
||||||
if (
|
if (
|
||||||
asyncio.iscoroutinefunction(characteristic.read_value) and
|
asyncio.iscoroutinefunction(characteristic.read_value) and
|
||||||
@@ -317,11 +318,21 @@ class CharacteristicAdapter:
|
|||||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||||
self.subscribe = self.wrapped_subscribe
|
self.subscribe = self.wrapped_subscribe
|
||||||
|
|
||||||
|
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||||
|
self.unsubscribe = self.wrapped_unsubscribe
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.wrapped_characteristic, name)
|
return getattr(self.wrapped_characteristic, name)
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
|
if name in {
|
||||||
|
'wrapped_characteristic',
|
||||||
|
'subscribers',
|
||||||
|
'read_value',
|
||||||
|
'write_value',
|
||||||
|
'subscribe',
|
||||||
|
'unsubscribe'
|
||||||
|
}:
|
||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
else:
|
else:
|
||||||
setattr(self.wrapped_characteristic, name, value)
|
setattr(self.wrapped_characteristic, name, value)
|
||||||
@@ -345,9 +356,26 @@ class CharacteristicAdapter:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def wrapped_subscribe(self, subscriber=None):
|
def wrapped_subscribe(self, subscriber=None):
|
||||||
return self.wrapped_characteristic.subscribe(
|
if subscriber is not None:
|
||||||
None if subscriber is None else lambda value: subscriber(self.decode_value(value))
|
if subscriber in self.subscribers:
|
||||||
)
|
# We already have a proxy subscriber
|
||||||
|
subscriber = self.subscribers[subscriber]
|
||||||
|
else:
|
||||||
|
# Create and register a proxy that will decode the value
|
||||||
|
original_subscriber = subscriber
|
||||||
|
|
||||||
|
def on_change(value):
|
||||||
|
original_subscriber(self.decode_value(value))
|
||||||
|
self.subscribers[subscriber] = on_change
|
||||||
|
subscriber = on_change
|
||||||
|
|
||||||
|
return self.wrapped_characteristic.subscribe(subscriber)
|
||||||
|
|
||||||
|
def wrapped_unsubscribe(self, subscriber=None):
|
||||||
|
if subscriber in self.subscribers:
|
||||||
|
subscriber = self.subscribers.pop(subscriber)
|
||||||
|
|
||||||
|
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
wrapped = str(self.wrapped_characteristic)
|
wrapped = str(self.wrapped_characteristic)
|
||||||
|
|||||||
+39
-9
@@ -58,10 +58,16 @@ class AttributeProxy(EventEmitter):
|
|||||||
self.type = attribute_type
|
self.type = attribute_type
|
||||||
|
|
||||||
async def read_value(self, no_long_read=False):
|
async def read_value(self, no_long_read=False):
|
||||||
return await self.client.read_value(self.handle, no_long_read)
|
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
|
||||||
|
|
||||||
async def write_value(self, value, with_response=False):
|
async def write_value(self, value, with_response=False):
|
||||||
return await self.client.write_value(self.handle, value, with_response)
|
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
|
||||||
|
|
||||||
|
def encode_value(self, value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
def decode_value(self, value_bytes):
|
||||||
|
return value_bytes
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
|
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
|
||||||
@@ -98,6 +104,7 @@ class CharacteristicProxy(AttributeProxy):
|
|||||||
self.properties = properties
|
self.properties = properties
|
||||||
self.descriptors = []
|
self.descriptors = []
|
||||||
self.descriptors_discovered = False
|
self.descriptors_discovered = False
|
||||||
|
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||||
|
|
||||||
def get_descriptor(self, descriptor_type):
|
def get_descriptor(self, descriptor_type):
|
||||||
for descriptor in self.descriptors:
|
for descriptor in self.descriptors:
|
||||||
@@ -108,9 +115,25 @@ class CharacteristicProxy(AttributeProxy):
|
|||||||
return await self.client.discover_descriptors(self)
|
return await self.client.discover_descriptors(self)
|
||||||
|
|
||||||
async def subscribe(self, subscriber=None):
|
async def subscribe(self, subscriber=None):
|
||||||
|
if subscriber is not None:
|
||||||
|
if subscriber in self.subscribers:
|
||||||
|
# We already have a proxy subscriber
|
||||||
|
subscriber = self.subscribers[subscriber]
|
||||||
|
else:
|
||||||
|
# Create and register a proxy that will decode the value
|
||||||
|
original_subscriber = subscriber
|
||||||
|
|
||||||
|
def on_change(value):
|
||||||
|
original_subscriber(self.decode_value(value))
|
||||||
|
self.subscribers[subscriber] = on_change
|
||||||
|
subscriber = on_change
|
||||||
|
|
||||||
return await self.client.subscribe(self, subscriber)
|
return await self.client.subscribe(self, subscriber)
|
||||||
|
|
||||||
async def unsubscribe(self, subscriber=None):
|
async def unsubscribe(self, subscriber=None):
|
||||||
|
if subscriber in self.subscribers:
|
||||||
|
subscriber = self.subscribers.pop(subscriber)
|
||||||
|
|
||||||
return await self.client.unsubscribe(self, subscriber)
|
return await self.client.unsubscribe(self, subscriber)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -140,7 +163,6 @@ class ProfileServiceProxy:
|
|||||||
class Client:
|
class Client:
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.mtu = ATT_DEFAULT_MTU
|
|
||||||
self.mtu_exchange_done = False
|
self.mtu_exchange_done = False
|
||||||
self.request_semaphore = asyncio.Semaphore(1)
|
self.request_semaphore = asyncio.Semaphore(1)
|
||||||
self.pending_request = None
|
self.pending_request = None
|
||||||
@@ -194,7 +216,7 @@ class Client:
|
|||||||
|
|
||||||
# We can only send one request per connection
|
# We can only send one request per connection
|
||||||
if self.mtu_exchange_done:
|
if self.mtu_exchange_done:
|
||||||
return
|
return self.connection.att_mtu
|
||||||
|
|
||||||
# Send the request
|
# Send the request
|
||||||
self.mtu_exchange_done = True
|
self.mtu_exchange_done = True
|
||||||
@@ -207,8 +229,10 @@ class Client:
|
|||||||
response
|
response
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
|
# Compute the final MTU
|
||||||
return self.mtu
|
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
|
||||||
|
|
||||||
|
return self.connection.att_mtu
|
||||||
|
|
||||||
def get_services_by_uuid(self, uuid):
|
def get_services_by_uuid(self, uuid):
|
||||||
return [service for service in self.services if service.uuid == uuid]
|
return [service for service in self.services if service.uuid == uuid]
|
||||||
@@ -570,12 +594,18 @@ class Client:
|
|||||||
subscribers = subscriber_set.get(characteristic.handle, [])
|
subscribers = subscriber_set.get(characteristic.handle, [])
|
||||||
if subscriber in subscribers:
|
if subscriber in subscribers:
|
||||||
subscribers.remove(subscriber)
|
subscribers.remove(subscriber)
|
||||||
|
|
||||||
|
# Cleanup if we removed the last one
|
||||||
|
if not subscribers:
|
||||||
|
subscriber_set.remove(characteristic.handle)
|
||||||
else:
|
else:
|
||||||
# Remove all subscribers for this attribute from the sets!
|
# Remove all subscribers for this attribute from the sets!
|
||||||
self.notification_subscribers.pop(characteristic.handle, None)
|
self.notification_subscribers.pop(characteristic.handle, None)
|
||||||
self.indication_subscribers.pop(characteristic.handle, None)
|
self.indication_subscribers.pop(characteristic.handle, None)
|
||||||
|
|
||||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
if not self.notification_subscribers and not self.indication_subscribers:
|
||||||
|
# No more subscribers left
|
||||||
|
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||||
|
|
||||||
async def read_value(self, attribute, no_long_read=False):
|
async def read_value(self, attribute, no_long_read=False):
|
||||||
'''
|
'''
|
||||||
@@ -600,7 +630,7 @@ class Client:
|
|||||||
# If the value is the max size for the MTU, try to read more unless the caller
|
# If the value is the max size for the MTU, try to read more unless the caller
|
||||||
# specifically asked not to do that
|
# specifically asked not to do that
|
||||||
attribute_value = response.attribute_value
|
attribute_value = response.attribute_value
|
||||||
if not no_long_read and len(attribute_value) == self.mtu - 1:
|
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
|
||||||
logger.debug('using READ BLOB to get the rest of the value')
|
logger.debug('using READ BLOB to get the rest of the value')
|
||||||
offset = len(attribute_value)
|
offset = len(attribute_value)
|
||||||
while True:
|
while True:
|
||||||
@@ -622,7 +652,7 @@ class Client:
|
|||||||
part = response.part_attribute_value
|
part = response.part_attribute_value
|
||||||
attribute_value += part
|
attribute_value += part
|
||||||
|
|
||||||
if len(part) < self.mtu - 1:
|
if len(part) < self.connection.att_mtu - 1:
|
||||||
break
|
break
|
||||||
|
|
||||||
offset += len(part)
|
offset += len(part)
|
||||||
|
|||||||
+47
-59
@@ -40,6 +40,12 @@ from .gatt import *
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# GATT Server
|
# GATT Server
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -49,9 +55,8 @@ class Server(EventEmitter):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.attributes = [] # Attributes, ordered by increasing handle values
|
self.attributes = [] # Attributes, ordered by increasing handle values
|
||||||
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
||||||
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
|
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
|
||||||
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
|
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
|
||||||
self.mtus = {} # Map of ATT MTU values by connection handle
|
|
||||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||||
self.pending_confirmations = defaultdict(lambda: None)
|
self.pending_confirmations = defaultdict(lambda: None)
|
||||||
|
|
||||||
@@ -169,7 +174,7 @@ class Server(EventEmitter):
|
|||||||
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
|
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
|
||||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||||
|
|
||||||
async def notify_subscriber(self, connection, attribute, force=False):
|
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
if not force:
|
if not force:
|
||||||
subscribers = self.subscribers.get(connection.handle)
|
subscribers = self.subscribers.get(connection.handle)
|
||||||
@@ -184,13 +189,12 @@ class Server(EventEmitter):
|
|||||||
logger.debug(f'not notifying, cccd={cccd.hex()}')
|
logger.debug(f'not notifying, cccd={cccd.hex()}')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the value
|
# Get or encode the value
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
mtu = self.get_mtu(connection)
|
if len(value) > connection.att_mtu - 3:
|
||||||
if len(value) > mtu - 3:
|
value = value[:connection.att_mtu - 3]
|
||||||
value = value[:mtu - 3]
|
|
||||||
|
|
||||||
# Notify
|
# Notify
|
||||||
notification = ATT_Handle_Value_Notification(
|
notification = ATT_Handle_Value_Notification(
|
||||||
@@ -198,27 +202,9 @@ class Server(EventEmitter):
|
|||||||
attribute_value = value
|
attribute_value = value
|
||||||
)
|
)
|
||||||
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
|
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
|
||||||
self.send_gatt_pdu(connection.handle, notification.to_bytes())
|
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||||
|
|
||||||
async def notify_subscribers(self, attribute, force=False):
|
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||||
# Get all the connections for which there's at least one subscription
|
|
||||||
connections = [
|
|
||||||
connection for connection in [
|
|
||||||
self.device.lookup_connection(connection_handle)
|
|
||||||
for (connection_handle, subscribers) in self.subscribers.items()
|
|
||||||
if force or subscribers.get(attribute.handle)
|
|
||||||
]
|
|
||||||
if connection is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
# Notify for each connection
|
|
||||||
if connections:
|
|
||||||
await asyncio.wait([
|
|
||||||
self.notify_subscriber(connection, attribute, force)
|
|
||||||
for connection in connections
|
|
||||||
])
|
|
||||||
|
|
||||||
async def indicate_subscriber(self, connection, attribute, force=False):
|
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
if not force:
|
if not force:
|
||||||
subscribers = self.subscribers.get(connection.handle)
|
subscribers = self.subscribers.get(connection.handle)
|
||||||
@@ -233,13 +219,12 @@ class Server(EventEmitter):
|
|||||||
logger.debug(f'not indicating, cccd={cccd.hex()}')
|
logger.debug(f'not indicating, cccd={cccd.hex()}')
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the value
|
# Get or encode the value
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
mtu = self.get_mtu(connection)
|
if len(value) > connection.att_mtu - 3:
|
||||||
if len(value) > mtu - 3:
|
value = value[:connection.att_mtu - 3]
|
||||||
value = value[:mtu - 3]
|
|
||||||
|
|
||||||
# Indicate
|
# Indicate
|
||||||
indication = ATT_Handle_Value_Indication(
|
indication = ATT_Handle_Value_Indication(
|
||||||
@@ -264,27 +249,32 @@ class Server(EventEmitter):
|
|||||||
finally:
|
finally:
|
||||||
self.pending_confirmations[connection.handle] = None
|
self.pending_confirmations[connection.handle] = None
|
||||||
|
|
||||||
async def indicate_subscribers(self, attribute):
|
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
|
||||||
# Get all the connections for which there's at least one subscription
|
# Get all the connections for which there's at least one subscription
|
||||||
connections = [
|
connections = [
|
||||||
connection for connection in [
|
connection for connection in [
|
||||||
self.device.lookup_connection(connection_handle)
|
self.device.lookup_connection(connection_handle)
|
||||||
for (connection_handle, subscribers) in self.subscribers.items()
|
for (connection_handle, subscribers) in self.subscribers.items()
|
||||||
if subscribers.get(attribute.handle)
|
if force or subscribers.get(attribute.handle)
|
||||||
]
|
]
|
||||||
if connection is not None
|
if connection is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
# Indicate for each connection
|
# Indicate or notify for each connection
|
||||||
if connections:
|
if connections:
|
||||||
|
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
|
||||||
await asyncio.wait([
|
await asyncio.wait([
|
||||||
self.indicate_subscriber(connection, attribute)
|
asyncio.create_task(coroutine(connection, attribute, value, force))
|
||||||
for connection in connections
|
for connection in connections
|
||||||
])
|
])
|
||||||
|
|
||||||
|
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||||
|
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||||
|
|
||||||
|
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||||
|
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||||
|
|
||||||
def on_disconnection(self, connection):
|
def on_disconnection(self, connection):
|
||||||
if connection.handle in self.mtus:
|
|
||||||
del self.mtus[connection.handle]
|
|
||||||
if connection.handle in self.subscribers:
|
if connection.handle in self.subscribers:
|
||||||
del self.subscribers[connection.handle]
|
del self.subscribers[connection.handle]
|
||||||
if connection.handle in self.indication_semaphores:
|
if connection.handle in self.indication_semaphores:
|
||||||
@@ -325,9 +315,6 @@ class Server(EventEmitter):
|
|||||||
# Just ignore
|
# Just ignore
|
||||||
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
|
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
|
||||||
|
|
||||||
def get_mtu(self, connection):
|
|
||||||
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# ATT handlers
|
# ATT handlers
|
||||||
#######################################################
|
#######################################################
|
||||||
@@ -347,12 +334,16 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
||||||
'''
|
'''
|
||||||
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
|
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
|
||||||
self.mtus[connection.handle] = mtu
|
|
||||||
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
|
|
||||||
|
|
||||||
# Notify the device
|
# Compute the final MTU
|
||||||
self.device.on_connection_att_mtu_update(connection.handle, mtu)
|
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
|
||||||
|
mtu = min(self.max_mtu, request.client_rx_mtu)
|
||||||
|
|
||||||
|
# Notify the device
|
||||||
|
self.device.on_connection_att_mtu_update(connection.handle, mtu)
|
||||||
|
else:
|
||||||
|
logger.warning('invalid client_rx_mtu received, MTU not changed')
|
||||||
|
|
||||||
def on_att_find_information_request(self, connection, request):
|
def on_att_find_information_request(self, connection, request):
|
||||||
'''
|
'''
|
||||||
@@ -369,7 +360,7 @@ class Server(EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Build list of returned attributes
|
# Build list of returned attributes
|
||||||
pdu_space_available = self.get_mtu(connection) - 2
|
pdu_space_available = connection.att_mtu - 2
|
||||||
attributes = []
|
attributes = []
|
||||||
uuid_size = 0
|
uuid_size = 0
|
||||||
for attribute in (
|
for attribute in (
|
||||||
@@ -420,7 +411,7 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
# Build list of returned attributes
|
# Build list of returned attributes
|
||||||
pdu_space_available = self.get_mtu(connection) - 2
|
pdu_space_available = connection.att_mtu - 2
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -468,8 +459,7 @@ class Server(EventEmitter):
|
|||||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||||
'''
|
'''
|
||||||
|
|
||||||
mtu = self.get_mtu(connection)
|
pdu_space_available = connection.att_mtu - 2
|
||||||
pdu_space_available = mtu - 2
|
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -482,7 +472,7 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
# Check the attribute value size
|
# Check the attribute value size
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
max_attribute_size = min(mtu - 4, 253)
|
max_attribute_size = min(connection.att_mtu - 4, 253)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
attribute_value = attribute_value[:max_attribute_size]
|
attribute_value = attribute_value[:max_attribute_size]
|
||||||
@@ -522,7 +512,7 @@ class Server(EventEmitter):
|
|||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
# TODO: check permissions
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
value_size = min(self.get_mtu(connection) - 1, len(value))
|
value_size = min(connection.att_mtu - 1, len(value))
|
||||||
response = ATT_Read_Response(
|
response = ATT_Read_Response(
|
||||||
attribute_value = value[:value_size]
|
attribute_value = value[:value_size]
|
||||||
)
|
)
|
||||||
@@ -541,7 +531,6 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
# TODO: check permissions
|
||||||
mtu = self.get_mtu(connection)
|
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
if request.value_offset > len(value):
|
if request.value_offset > len(value):
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
@@ -549,14 +538,14 @@ class Server(EventEmitter):
|
|||||||
attribute_handle_in_error = request.attribute_handle,
|
attribute_handle_in_error = request.attribute_handle,
|
||||||
error_code = ATT_INVALID_OFFSET_ERROR
|
error_code = ATT_INVALID_OFFSET_ERROR
|
||||||
)
|
)
|
||||||
elif len(value) <= mtu - 1:
|
elif len(value) <= connection.att_mtu - 1:
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error = request.op_code,
|
request_opcode_in_error = request.op_code,
|
||||||
attribute_handle_in_error = request.attribute_handle,
|
attribute_handle_in_error = request.attribute_handle,
|
||||||
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
|
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
part_size = min(mtu - 1, len(value) - request.value_offset)
|
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
|
||||||
response = ATT_Read_Blob_Response(
|
response = ATT_Read_Blob_Response(
|
||||||
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
|
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
|
||||||
)
|
)
|
||||||
@@ -585,8 +574,7 @@ class Server(EventEmitter):
|
|||||||
self.send_response(connection, response)
|
self.send_response(connection, response)
|
||||||
return
|
return
|
||||||
|
|
||||||
mtu = self.get_mtu(connection)
|
pdu_space_available = connection.att_mtu - 2
|
||||||
pdu_space_available = mtu - 2
|
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -597,7 +585,7 @@ class Server(EventEmitter):
|
|||||||
):
|
):
|
||||||
# Check the attribute value size
|
# Check the attribute value size
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
max_attribute_size = min(mtu - 6, 251)
|
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
attribute_value = attribute_value[:max_attribute_size]
|
attribute_value = attribute_value[:max_attribute_size]
|
||||||
|
|||||||
+246
-76
@@ -22,6 +22,7 @@ import struct
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from bumble.controller import Controller
|
from bumble.controller import Controller
|
||||||
|
from bumble.gatt_client import CharacteristicProxy
|
||||||
from bumble.link import LocalLink
|
from bumble.link import LocalLink
|
||||||
from bumble.device import Device, Peer
|
from bumble.device import Device, Peer
|
||||||
from bumble.host import Host
|
from bumble.host import Host
|
||||||
@@ -53,29 +54,29 @@ def basic_check(x):
|
|||||||
parsed = ATT_PDU.from_bytes(pdu)
|
parsed = ATT_PDU.from_bytes(pdu)
|
||||||
x_str = str(x)
|
x_str = str(x)
|
||||||
parsed_str = str(parsed)
|
parsed_str = str(parsed)
|
||||||
assert(x_str == parsed_str)
|
assert x_str == parsed_str
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_UUID():
|
def test_UUID():
|
||||||
u = UUID.from_16_bits(0x7788)
|
u = UUID.from_16_bits(0x7788)
|
||||||
assert(str(u) == 'UUID-16:7788')
|
assert str(u) == 'UUID-16:7788'
|
||||||
u = UUID.from_32_bits(0x11223344)
|
u = UUID.from_32_bits(0x11223344)
|
||||||
assert(str(u) == 'UUID-32:11223344')
|
assert str(u) == 'UUID-32:11223344'
|
||||||
u = UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
u = UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
||||||
assert(str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||||
v = UUID(str(u))
|
v = UUID(str(u))
|
||||||
assert(str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||||
w = UUID.from_bytes(v.to_bytes())
|
w = UUID.from_bytes(v.to_bytes())
|
||||||
assert(str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||||
|
|
||||||
u1 = UUID.from_16_bits(0x1234)
|
u1 = UUID.from_16_bits(0x1234)
|
||||||
b1 = u1.to_bytes(force_128 = True)
|
b1 = u1.to_bytes(force_128 = True)
|
||||||
u2 = UUID.from_bytes(b1)
|
u2 = UUID.from_bytes(b1)
|
||||||
assert(u1 == u2)
|
assert u1 == u2
|
||||||
|
|
||||||
u3 = UUID.from_16_bits(0x180a)
|
u3 = UUID.from_16_bits(0x180a)
|
||||||
assert(str(u3) == 'UUID-16:180A (Device Information)')
|
assert str(u3) == 'UUID-16:180A (Device Information)'
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -98,6 +99,122 @@ def test_ATT_Read_By_Group_Type_Request():
|
|||||||
basic_check(pdu)
|
basic_check(pdu)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_characteristic_encoding():
|
||||||
|
class Foo(Characteristic):
|
||||||
|
def encode_value(self, value):
|
||||||
|
return bytes([value])
|
||||||
|
|
||||||
|
def decode_value(self, value_bytes):
|
||||||
|
return value_bytes[0]
|
||||||
|
|
||||||
|
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123)
|
||||||
|
x = c.read_value(None)
|
||||||
|
assert x == bytes([123])
|
||||||
|
c.write_value(None, bytes([122]))
|
||||||
|
assert c.value == 122
|
||||||
|
|
||||||
|
class FooProxy(CharacteristicProxy):
|
||||||
|
def __init__(self, characteristic):
|
||||||
|
super().__init__(
|
||||||
|
characteristic.client,
|
||||||
|
characteristic.handle,
|
||||||
|
characteristic.end_group_handle,
|
||||||
|
characteristic.uuid,
|
||||||
|
characteristic.properties
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_value(self, value):
|
||||||
|
return bytes([value])
|
||||||
|
|
||||||
|
def decode_value(self, value_bytes):
|
||||||
|
return value_bytes[0]
|
||||||
|
|
||||||
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
|
characteristic = Characteristic(
|
||||||
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
|
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
|
||||||
|
Characteristic.READABLE | Characteristic.WRITEABLE,
|
||||||
|
bytes([123])
|
||||||
|
)
|
||||||
|
|
||||||
|
service = Service(
|
||||||
|
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
|
||||||
|
[characteristic]
|
||||||
|
)
|
||||||
|
server.add_service(service)
|
||||||
|
|
||||||
|
await client.power_on()
|
||||||
|
await server.power_on()
|
||||||
|
connection = await client.connect(server.random_address)
|
||||||
|
peer = Peer(connection)
|
||||||
|
|
||||||
|
await peer.discover_services()
|
||||||
|
await peer.discover_characteristics()
|
||||||
|
c = peer.get_characteristics_by_uuid(characteristic.uuid)
|
||||||
|
assert len(c) == 1
|
||||||
|
c = c[0]
|
||||||
|
cp = FooProxy(c)
|
||||||
|
|
||||||
|
v = await cp.read_value()
|
||||||
|
assert v == 123
|
||||||
|
await cp.write_value(124)
|
||||||
|
await async_barrier()
|
||||||
|
assert characteristic.value == bytes([124])
|
||||||
|
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
def on_change(value):
|
||||||
|
nonlocal last_change
|
||||||
|
last_change = value
|
||||||
|
|
||||||
|
await c.subscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change == characteristic.value
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
await server.notify_subscribers(characteristic, value=bytes([125]))
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change == bytes([125])
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
await c.unsubscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change is None
|
||||||
|
|
||||||
|
await cp.subscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change == characteristic.value[0]
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
await server.notify_subscribers(characteristic, value=bytes([126]))
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change == 126
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
await cp.unsubscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change is None
|
||||||
|
|
||||||
|
cd = DelegatedCharacteristicAdapter(c, decode=lambda x: x[0])
|
||||||
|
await cd.subscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change == characteristic.value[0]
|
||||||
|
last_change = None
|
||||||
|
|
||||||
|
await cd.unsubscribe(on_change)
|
||||||
|
await server.notify_subscribers(characteristic)
|
||||||
|
await async_barrier()
|
||||||
|
assert last_change is None
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_CharacteristicAdapter():
|
def test_CharacteristicAdapter():
|
||||||
# Check that the CharacteristicAdapter base class is transparent
|
# Check that the CharacteristicAdapter base class is transparent
|
||||||
@@ -106,21 +223,21 @@ def test_CharacteristicAdapter():
|
|||||||
a = CharacteristicAdapter(c)
|
a = CharacteristicAdapter(c)
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == v)
|
assert value == v
|
||||||
|
|
||||||
v = bytes([3, 4, 5])
|
v = bytes([3, 4, 5])
|
||||||
a.write_value(None, v)
|
a.write_value(None, v)
|
||||||
assert(c.value == v)
|
assert c.value == v
|
||||||
|
|
||||||
# Simple delegated adapter
|
# Simple delegated adapter
|
||||||
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
|
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == bytes(reversed(v)))
|
assert value == bytes(reversed(v))
|
||||||
|
|
||||||
v = bytes([3, 4, 5])
|
v = bytes([3, 4, 5])
|
||||||
a.write_value(None, v)
|
a.write_value(None, v)
|
||||||
assert(a.value == bytes(reversed(v)))
|
assert a.value == bytes(reversed(v))
|
||||||
|
|
||||||
# Packed adapter with single element format
|
# Packed adapter with single element format
|
||||||
v = 1234
|
v = 1234
|
||||||
@@ -129,10 +246,10 @@ def test_CharacteristicAdapter():
|
|||||||
a = PackedCharacteristicAdapter(c, '>H')
|
a = PackedCharacteristicAdapter(c, '>H')
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == pv)
|
assert value == pv
|
||||||
c.value = None
|
c.value = None
|
||||||
a.write_value(None, pv)
|
a.write_value(None, pv)
|
||||||
assert(a.value == v)
|
assert a.value == v
|
||||||
|
|
||||||
# Packed adapter with multi-element format
|
# Packed adapter with multi-element format
|
||||||
v1 = 1234
|
v1 = 1234
|
||||||
@@ -142,10 +259,10 @@ def test_CharacteristicAdapter():
|
|||||||
a = PackedCharacteristicAdapter(c, '>HH')
|
a = PackedCharacteristicAdapter(c, '>HH')
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == pv)
|
assert value == pv
|
||||||
c.value = None
|
c.value = None
|
||||||
a.write_value(None, pv)
|
a.write_value(None, pv)
|
||||||
assert(a.value == (v1, v2))
|
assert a.value == (v1, v2)
|
||||||
|
|
||||||
# Mapped adapter
|
# Mapped adapter
|
||||||
v1 = 1234
|
v1 = 1234
|
||||||
@@ -156,10 +273,10 @@ def test_CharacteristicAdapter():
|
|||||||
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
|
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == pv)
|
assert value == pv
|
||||||
c.value = None
|
c.value = None
|
||||||
a.write_value(None, pv)
|
a.write_value(None, pv)
|
||||||
assert(a.value == mapped)
|
assert a.value == mapped
|
||||||
|
|
||||||
# UTF-8 adapter
|
# UTF-8 adapter
|
||||||
v = 'Hello π'
|
v = 'Hello π'
|
||||||
@@ -168,10 +285,10 @@ def test_CharacteristicAdapter():
|
|||||||
a = UTF8CharacteristicAdapter(c)
|
a = UTF8CharacteristicAdapter(c)
|
||||||
|
|
||||||
value = a.read_value(None)
|
value = a.read_value(None)
|
||||||
assert(value == ev)
|
assert value == ev
|
||||||
c.value = None
|
c.value = None
|
||||||
a.write_value(None, ev)
|
a.write_value(None, ev)
|
||||||
assert(a.value == v)
|
assert a.value == v
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -179,24 +296,25 @@ def test_CharacteristicValue():
|
|||||||
b = bytes([1, 2, 3])
|
b = bytes([1, 2, 3])
|
||||||
c = CharacteristicValue(read=lambda _: b)
|
c = CharacteristicValue(read=lambda _: b)
|
||||||
x = c.read(None)
|
x = c.read(None)
|
||||||
assert(x == b)
|
assert x == b
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
|
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
|
||||||
z = object()
|
z = object()
|
||||||
c.write(z, b)
|
c.write(z, b)
|
||||||
assert(result == [(z, b)])
|
assert result == [(z, b)]
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class TwoDevices:
|
class LinkedDevices:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.connections = [None, None]
|
self.connections = [None, None, None]
|
||||||
|
|
||||||
self.link = LocalLink()
|
self.link = LocalLink()
|
||||||
self.controllers = [
|
self.controllers = [
|
||||||
Controller('C1', link = self.link),
|
Controller('C1', link = self.link),
|
||||||
Controller('C2', link = self.link)
|
Controller('C2', link = self.link),
|
||||||
|
Controller('C3', link = self.link)
|
||||||
]
|
]
|
||||||
self.devices = [
|
self.devices = [
|
||||||
Device(
|
Device(
|
||||||
@@ -204,12 +322,16 @@ class TwoDevices:
|
|||||||
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
|
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
|
||||||
),
|
),
|
||||||
Device(
|
Device(
|
||||||
address = 'F5:F4:F3:F2:F1:F0',
|
address = 'F1:F2:F3:F4:F5:F6',
|
||||||
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
|
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
|
||||||
|
),
|
||||||
|
Device(
|
||||||
|
address = 'F2:F3:F4:F5:F6:F7',
|
||||||
|
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.paired = [None, None]
|
self.paired = [None, None, None]
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -222,7 +344,7 @@ async def async_barrier():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_read_write():
|
async def test_read_write():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
@@ -265,41 +387,41 @@ async def test_read_write():
|
|||||||
await peer.discover_services()
|
await peer.discover_services()
|
||||||
await peer.discover_characteristics()
|
await peer.discover_characteristics()
|
||||||
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c1 = c[0]
|
c1 = c[0]
|
||||||
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c2 = c[0]
|
c2 = c[0]
|
||||||
|
|
||||||
v1 = await peer.read_value(c1)
|
v1 = await peer.read_value(c1)
|
||||||
assert(v1 == b'')
|
assert v1 == b''
|
||||||
b = bytes([1, 2, 3])
|
b = bytes([1, 2, 3])
|
||||||
await peer.write_value(c1, b)
|
await peer.write_value(c1, b)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(characteristic1.value == b)
|
assert characteristic1.value == b
|
||||||
v1 = await peer.read_value(c1)
|
v1 = await peer.read_value(c1)
|
||||||
assert(v1 == b)
|
assert v1 == b
|
||||||
assert(type(characteristic1._last_value) is tuple)
|
assert type(characteristic1._last_value is tuple)
|
||||||
assert(len(characteristic1._last_value) == 2)
|
assert len(characteristic1._last_value) == 2
|
||||||
assert(str(characteristic1._last_value[0].peer_address) == str(client.random_address))
|
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address)
|
||||||
assert(characteristic1._last_value[1] == b)
|
assert characteristic1._last_value[1] == b
|
||||||
bb = bytes([3, 4, 5, 6])
|
bb = bytes([3, 4, 5, 6])
|
||||||
characteristic1.value = bb
|
characteristic1.value = bb
|
||||||
v1 = await peer.read_value(c1)
|
v1 = await peer.read_value(c1)
|
||||||
assert(v1 == bb)
|
assert v1 == bb
|
||||||
|
|
||||||
await peer.write_value(c2, b)
|
await peer.write_value(c2, b)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(type(characteristic2._last_value) is tuple)
|
assert type(characteristic2._last_value is tuple)
|
||||||
assert(len(characteristic2._last_value) == 2)
|
assert len(characteristic2._last_value) == 2
|
||||||
assert(str(characteristic2._last_value[0].peer_address) == str(client.random_address))
|
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address)
|
||||||
assert(characteristic2._last_value[1] == b)
|
assert characteristic2._last_value[1] == b
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_read_write2():
|
async def test_read_write2():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
v = bytes([0x11, 0x22, 0x33, 0x44])
|
v = bytes([0x11, 0x22, 0x33, 0x44])
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
@@ -324,32 +446,32 @@ async def test_read_write2():
|
|||||||
|
|
||||||
await peer.discover_services()
|
await peer.discover_services()
|
||||||
c = peer.get_services_by_uuid(service1.uuid)
|
c = peer.get_services_by_uuid(service1.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
s = c[0]
|
s = c[0]
|
||||||
await s.discover_characteristics()
|
await s.discover_characteristics()
|
||||||
c = s.get_characteristics_by_uuid(characteristic1.uuid)
|
c = s.get_characteristics_by_uuid(characteristic1.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c1 = c[0]
|
c1 = c[0]
|
||||||
|
|
||||||
v1 = await c1.read_value()
|
v1 = await c1.read_value()
|
||||||
assert(v1 == v)
|
assert v1 == v
|
||||||
|
|
||||||
a1 = PackedCharacteristicAdapter(c1, '>I')
|
a1 = PackedCharacteristicAdapter(c1, '>I')
|
||||||
v1 = await a1.read_value()
|
v1 = await a1.read_value()
|
||||||
assert(v1 == struct.unpack('>I', v)[0])
|
assert v1 == struct.unpack('>I', v)[0]
|
||||||
|
|
||||||
b = bytes([0x55, 0x66, 0x77, 0x88])
|
b = bytes([0x55, 0x66, 0x77, 0x88])
|
||||||
await a1.write_value(struct.unpack('>I', b)[0])
|
await a1.write_value(struct.unpack('>I', b)[0])
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(characteristic1.value == b)
|
assert characteristic1.value == b
|
||||||
v1 = await a1.read_value()
|
v1 = await a1.read_value()
|
||||||
assert(v1 == struct.unpack('>I', b)[0])
|
assert v1 == struct.unpack('>I', b)[0]
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscribe_notify():
|
async def test_subscribe_notify():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
@@ -410,13 +532,13 @@ async def test_subscribe_notify():
|
|||||||
await peer.discover_services()
|
await peer.discover_services()
|
||||||
await peer.discover_characteristics()
|
await peer.discover_characteristics()
|
||||||
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c1 = c[0]
|
c1 = c[0]
|
||||||
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c2 = c[0]
|
c2 = c[0]
|
||||||
c = peer.get_characteristics_by_uuid(characteristic3.uuid)
|
c = peer.get_characteristics_by_uuid(characteristic3.uuid)
|
||||||
assert(len(c) == 1)
|
assert len(c) == 1
|
||||||
c3 = c[0]
|
c3 = c[0]
|
||||||
|
|
||||||
c1._called = False
|
c1._called = False
|
||||||
@@ -429,23 +551,32 @@ async def test_subscribe_notify():
|
|||||||
c1.on('update', on_c1_update)
|
c1.on('update', on_c1_update)
|
||||||
await peer.subscribe(c1)
|
await peer.subscribe(c1)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(server._last_subscription[1] == characteristic1)
|
assert server._last_subscription[1] == characteristic1
|
||||||
assert(server._last_subscription[2])
|
assert server._last_subscription[2]
|
||||||
assert(not server._last_subscription[3])
|
assert not server._last_subscription[3]
|
||||||
assert(characteristic1._last_subscription[1])
|
assert characteristic1._last_subscription[1]
|
||||||
assert(not characteristic1._last_subscription[2])
|
assert not characteristic1._last_subscription[2]
|
||||||
await server.indicate_subscribers(characteristic1)
|
await server.indicate_subscribers(characteristic1)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(not c1._called)
|
assert not c1._called
|
||||||
await server.notify_subscribers(characteristic1)
|
await server.notify_subscribers(characteristic1)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(c1._called)
|
assert c1._called
|
||||||
assert(c1._last_update == characteristic1.value)
|
assert c1._last_update == characteristic1.value
|
||||||
|
|
||||||
|
c1._called = False
|
||||||
|
c1._last_update = None
|
||||||
|
c1_value = characteristic1.value
|
||||||
|
await server.notify_subscribers(characteristic1, bytes([0, 1, 2]))
|
||||||
|
await async_barrier()
|
||||||
|
assert c1._called
|
||||||
|
assert c1._last_update == bytes([0, 1, 2])
|
||||||
|
assert characteristic1.value == c1_value
|
||||||
|
|
||||||
c1._called = False
|
c1._called = False
|
||||||
await peer.unsubscribe(c1)
|
await peer.unsubscribe(c1)
|
||||||
await server.notify_subscribers(characteristic1)
|
await server.notify_subscribers(characteristic1)
|
||||||
assert(not c1._called)
|
assert not c1._called
|
||||||
|
|
||||||
c2._called = False
|
c2._called = False
|
||||||
c2._last_update = None
|
c2._last_update = None
|
||||||
@@ -458,17 +589,17 @@ async def test_subscribe_notify():
|
|||||||
await async_barrier()
|
await async_barrier()
|
||||||
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
|
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(not c2._called)
|
assert not c2._called
|
||||||
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(c2._called)
|
assert c2._called
|
||||||
assert(c2._last_update == characteristic2.value)
|
assert c2._last_update == characteristic2.value
|
||||||
|
|
||||||
c2._called = False
|
c2._called = False
|
||||||
await peer.unsubscribe(c2, on_c2_update)
|
await peer.unsubscribe(c2, on_c2_update)
|
||||||
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(not c2._called)
|
assert not c2._called
|
||||||
|
|
||||||
def on_c3_update(value):
|
def on_c3_update(value):
|
||||||
c3._called = True
|
c3._called = True
|
||||||
@@ -483,17 +614,17 @@ async def test_subscribe_notify():
|
|||||||
await async_barrier()
|
await async_barrier()
|
||||||
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(c3._called)
|
assert c3._called
|
||||||
assert(c3._last_update == characteristic3.value)
|
assert c3._last_update == characteristic3.value
|
||||||
assert(c3._called_2)
|
assert c3._called_2
|
||||||
assert(c3._last_update_2 == characteristic3.value)
|
assert c3._last_update_2 == characteristic3.value
|
||||||
characteristic3.value = bytes([1, 2, 3])
|
characteristic3.value = bytes([1, 2, 3])
|
||||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(c3._called)
|
assert c3._called
|
||||||
assert(c3._last_update == characteristic3.value)
|
assert c3._last_update == characteristic3.value
|
||||||
assert(c3._called_2)
|
assert c3._called_2
|
||||||
assert(c3._last_update_2 == characteristic3.value)
|
assert c3._last_update_2 == characteristic3.value
|
||||||
|
|
||||||
c3._called = False
|
c3._called = False
|
||||||
c3._called_2 = False
|
c3._called_2 = False
|
||||||
@@ -501,8 +632,44 @@ async def test_subscribe_notify():
|
|||||||
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||||
await async_barrier()
|
await async_barrier()
|
||||||
assert(not c3._called)
|
assert not c3._called
|
||||||
assert(not c3._called_2)
|
assert not c3._called_2
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mtu_exchange():
|
||||||
|
[d1, d2, d3] = LinkedDevices().devices[:3]
|
||||||
|
|
||||||
|
d3.gatt_server.max_mtu = 100
|
||||||
|
|
||||||
|
d3_connections = []
|
||||||
|
@d3.on('connection')
|
||||||
|
def on_d3_connection(connection):
|
||||||
|
d3_connections.append(connection)
|
||||||
|
|
||||||
|
await d1.power_on()
|
||||||
|
await d2.power_on()
|
||||||
|
await d3.power_on()
|
||||||
|
|
||||||
|
d1_connection = await d1.connect(d3.random_address)
|
||||||
|
assert len(d3_connections) == 1
|
||||||
|
assert d3_connections[0] is not None
|
||||||
|
|
||||||
|
d2_connection = await d2.connect(d3.random_address)
|
||||||
|
assert len(d3_connections) == 2
|
||||||
|
assert d3_connections[1] is not None
|
||||||
|
|
||||||
|
d1_peer = Peer(d1_connection)
|
||||||
|
d2_peer = Peer(d2_connection)
|
||||||
|
|
||||||
|
d1_client_mtu = await d1_peer.request_mtu(220)
|
||||||
|
assert d1_client_mtu == 100
|
||||||
|
assert d1_connection.att_mtu == 100
|
||||||
|
|
||||||
|
d2_client_mtu = await d2_peer.request_mtu(50)
|
||||||
|
assert d2_client_mtu == 50
|
||||||
|
assert d2_connection.att_mtu == 50
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -510,6 +677,9 @@ async def async_main():
|
|||||||
await test_read_write()
|
await test_read_write()
|
||||||
await test_read_write2()
|
await test_read_write2()
|
||||||
await test_subscribe_notify()
|
await test_subscribe_notify()
|
||||||
|
await test_characteristic_encoding()
|
||||||
|
await test_mtu_exchange()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user