forked from auracaster/bumble_mirror
Compare commits
4 Commits
gbg/rpa
...
gbg/gatt-n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80569bc9f3 | ||
|
|
daa05b8996 | ||
|
|
624e860762 | ||
|
|
159cbf7774 |
@@ -700,16 +700,26 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def encode_value(self, value):
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection):
|
||||
if read := getattr(self.value, 'read', None):
|
||||
try:
|
||||
return read(connection)
|
||||
value = read(connection)
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
|
||||
else:
|
||||
return self.value
|
||||
value = self.value
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection, value_bytes):
|
||||
value = self.decode_value(value_bytes)
|
||||
|
||||
def write_value(self, connection, value):
|
||||
if write := getattr(self.value, 'write', None):
|
||||
try:
|
||||
write(connection, value)
|
||||
@@ -721,7 +731,11 @@ class Attribute(EventEmitter):
|
||||
self.emit('write', connection, value)
|
||||
|
||||
def __repr__(self):
|
||||
if len(self.value) > 0:
|
||||
if type(self.value) is bytes:
|
||||
value_str = self.value.hex()
|
||||
else:
|
||||
value_str = str(self.value)
|
||||
if value_str:
|
||||
value_string = f', value={self.value.hex()}'
|
||||
else:
|
||||
value_string = ''
|
||||
|
||||
@@ -1210,17 +1210,17 @@ class Device(CompositeEventEmitter):
|
||||
def add_services(self, services):
|
||||
self.gatt_server.add_services(services)
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, force=False):
|
||||
await self.gatt_server.notify_subscriber(connection, attribute, force)
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def notify_subscribers(self, attribute, force=False):
|
||||
await self.gatt_server.notify_subscribers(attribute, force)
|
||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||
await self.gatt_server.notify_subscribers(attribute, value, force)
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, force=False):
|
||||
await self.gatt_server.indicate_subscriber(connection, attribute, force)
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(self, attribute):
|
||||
await self.gatt_server.indicate_subscribers(attribute)
|
||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||
await self.gatt_server.indicate_subscribers(attribute, value, force)
|
||||
|
||||
@host_event_handler
|
||||
def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters):
|
||||
|
||||
@@ -303,6 +303,7 @@ class CharacteristicAdapter:
|
||||
'''
|
||||
def __init__(self, characteristic):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
|
||||
if (
|
||||
asyncio.iscoroutinefunction(characteristic.read_value) and
|
||||
@@ -317,11 +318,21 @@ class CharacteristicAdapter:
|
||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.wrapped_characteristic, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
|
||||
if name in {
|
||||
'wrapped_characteristic',
|
||||
'subscribers',
|
||||
'read_value',
|
||||
'write_value',
|
||||
'subscribe',
|
||||
'unsubscribe'
|
||||
}:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
@@ -345,9 +356,26 @@ class CharacteristicAdapter:
|
||||
return value
|
||||
|
||||
def wrapped_subscribe(self, subscriber=None):
|
||||
return self.wrapped_characteristic.subscribe(
|
||||
None if subscriber is None else lambda value: subscriber(self.decode_value(value))
|
||||
)
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
# We already have a proxy subscriber
|
||||
subscriber = self.subscribers[subscriber]
|
||||
else:
|
||||
# Create and register a proxy that will decode the value
|
||||
original_subscriber = subscriber
|
||||
|
||||
def on_change(value):
|
||||
original_subscriber(self.decode_value(value))
|
||||
self.subscribers[subscriber] = on_change
|
||||
subscriber = on_change
|
||||
|
||||
return self.wrapped_characteristic.subscribe(subscriber)
|
||||
|
||||
def wrapped_unsubscribe(self, subscriber=None):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||
|
||||
def __str__(self):
|
||||
wrapped = str(self.wrapped_characteristic)
|
||||
|
||||
@@ -58,10 +58,16 @@ class AttributeProxy(EventEmitter):
|
||||
self.type = attribute_type
|
||||
|
||||
async def read_value(self, no_long_read=False):
|
||||
return await self.client.read_value(self.handle, no_long_read)
|
||||
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
|
||||
|
||||
async def write_value(self, value, with_response=False):
|
||||
return await self.client.write_value(self.handle, value, with_response)
|
||||
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
|
||||
|
||||
def encode_value(self, value):
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
return value_bytes
|
||||
|
||||
def __str__(self):
|
||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
|
||||
@@ -98,6 +104,7 @@ class CharacteristicProxy(AttributeProxy):
|
||||
self.properties = properties
|
||||
self.descriptors = []
|
||||
self.descriptors_discovered = False
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
|
||||
def get_descriptor(self, descriptor_type):
|
||||
for descriptor in self.descriptors:
|
||||
@@ -108,9 +115,25 @@ class CharacteristicProxy(AttributeProxy):
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(self, subscriber=None):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
# We already have a proxy subscriber
|
||||
subscriber = self.subscribers[subscriber]
|
||||
else:
|
||||
# Create and register a proxy that will decode the value
|
||||
original_subscriber = subscriber
|
||||
|
||||
def on_change(value):
|
||||
original_subscriber(self.decode_value(value))
|
||||
self.subscribers[subscriber] = on_change
|
||||
subscriber = on_change
|
||||
|
||||
return await self.client.subscribe(self, subscriber)
|
||||
|
||||
async def unsubscribe(self, subscriber=None):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber)
|
||||
|
||||
def __str__(self):
|
||||
@@ -570,12 +593,18 @@ class Client:
|
||||
subscribers = subscriber_set.get(characteristic.handle, [])
|
||||
if subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
subscriber_set.remove(characteristic.handle)
|
||||
else:
|
||||
# Remove all subscribers for this attribute from the sets!
|
||||
self.notification_subscribers.pop(characteristic.handle, None)
|
||||
self.indication_subscribers.pop(characteristic.handle, None)
|
||||
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
if not self.notification_subscribers and not self.indication_subscribers:
|
||||
# No more subscribers left
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
|
||||
async def read_value(self, attribute, no_long_read=False):
|
||||
'''
|
||||
|
||||
@@ -169,7 +169,7 @@ class Server(EventEmitter):
|
||||
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, force=False):
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -184,8 +184,8 @@ class Server(EventEmitter):
|
||||
logger.debug(f'not notifying, cccd={cccd.hex()}')
|
||||
return
|
||||
|
||||
# Get the value
|
||||
value = attribute.read_value(connection)
|
||||
# Get or encode the value
|
||||
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||
|
||||
# Truncate if needed
|
||||
mtu = self.get_mtu(connection)
|
||||
@@ -198,27 +198,9 @@ class Server(EventEmitter):
|
||||
attribute_value = value
|
||||
)
|
||||
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
|
||||
self.send_gatt_pdu(connection.handle, notification.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
|
||||
async def notify_subscribers(self, attribute, force=False):
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection for connection in [
|
||||
self.device.lookup_connection(connection_handle)
|
||||
for (connection_handle, subscribers) in self.subscribers.items()
|
||||
if force or subscribers.get(attribute.handle)
|
||||
]
|
||||
if connection is not None
|
||||
]
|
||||
|
||||
# Notify for each connection
|
||||
if connections:
|
||||
await asyncio.wait([
|
||||
self.notify_subscriber(connection, attribute, force)
|
||||
for connection in connections
|
||||
])
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, force=False):
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -233,8 +215,8 @@ class Server(EventEmitter):
|
||||
logger.debug(f'not indicating, cccd={cccd.hex()}')
|
||||
return
|
||||
|
||||
# Get the value
|
||||
value = attribute.read_value(connection)
|
||||
# Get or encode the value
|
||||
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||
|
||||
# Truncate if needed
|
||||
mtu = self.get_mtu(connection)
|
||||
@@ -264,24 +246,31 @@ class Server(EventEmitter):
|
||||
finally:
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
|
||||
async def indicate_subscribers(self, attribute):
|
||||
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection for connection in [
|
||||
self.device.lookup_connection(connection_handle)
|
||||
for (connection_handle, subscribers) in self.subscribers.items()
|
||||
if subscribers.get(attribute.handle)
|
||||
if force or subscribers.get(attribute.handle)
|
||||
]
|
||||
if connection is not None
|
||||
]
|
||||
|
||||
# Indicate for each connection
|
||||
# Indicate or notify for each connection
|
||||
if connections:
|
||||
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
|
||||
await asyncio.wait([
|
||||
self.indicate_subscriber(connection, attribute)
|
||||
asyncio.create_task(coroutine(connection, attribute, value, force))
|
||||
for connection in connections
|
||||
])
|
||||
|
||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection):
|
||||
if connection.handle in self.mtus:
|
||||
del self.mtus[connection.handle]
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
import logging
|
||||
from colors import color
|
||||
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
|
||||
from .core import name_or_number
|
||||
from .gatt import ATT_PDU, ATT_CID
|
||||
from .l2cap import (
|
||||
@@ -73,6 +75,9 @@ class PacketTracer:
|
||||
if l2cap_pdu.cid == ATT_CID:
|
||||
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(att_pdu)
|
||||
elif l2cap_pdu.cid == SMP_CID:
|
||||
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(smp_command)
|
||||
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
|
||||
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(control_frame)
|
||||
|
||||
@@ -155,6 +155,7 @@ SMP_CT2_AUTHREQ = 0b00100000
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -879,7 +880,7 @@ class Session:
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def derive_ltk(self):
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
@@ -914,7 +915,7 @@ class Session:
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = crypto.h7(
|
||||
@@ -946,7 +947,7 @@ class Session:
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
|
||||
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
|
||||
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = crypto.h7(
|
||||
@@ -980,12 +981,7 @@ class Session:
|
||||
self.peer_expected_distributions.remove(command_class)
|
||||
logger.debug(f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}')
|
||||
if not self.peer_expected_distributions:
|
||||
# The initiator can now send its keys
|
||||
if self.is_initiator:
|
||||
self.distribute_keys()
|
||||
|
||||
# Nothing left to expect, we're done
|
||||
asyncio.create_task(self.on_pairing())
|
||||
self.on_peer_key_distribution_complete()
|
||||
else:
|
||||
logger.warn(color(f'!!! unexpected key distribution command: {command_class.__name__}', 'red'))
|
||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
||||
@@ -1006,12 +1002,23 @@ class Session:
|
||||
self.connection.remove_listener('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh)
|
||||
self.manager.on_session_end(self)
|
||||
|
||||
def on_peer_key_distribution_complete(self):
|
||||
# The initiator can now send its keys
|
||||
if self.is_initiator:
|
||||
self.distribute_keys()
|
||||
|
||||
asyncio.create_task(self.on_pairing())
|
||||
|
||||
def on_connection_encryption_change(self):
|
||||
if self.connection.is_encrypted:
|
||||
if self.is_responder:
|
||||
# The responder distributes its keys first, the initiator later
|
||||
self.distribute_keys()
|
||||
|
||||
# If we're not expecting key distributions from the peer, we're done
|
||||
if not self.peer_expected_distributions:
|
||||
self.on_peer_key_distribution_complete()
|
||||
|
||||
def on_connection_encryption_key_refresh(self):
|
||||
# Do as if the connection had just been encrypted
|
||||
self.on_connection_encryption_change()
|
||||
|
||||
@@ -22,6 +22,7 @@ import struct
|
||||
import pytest
|
||||
|
||||
from bumble.controller import Controller
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.link import LocalLink
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.host import Host
|
||||
@@ -53,29 +54,29 @@ def basic_check(x):
|
||||
parsed = ATT_PDU.from_bytes(pdu)
|
||||
x_str = str(x)
|
||||
parsed_str = str(parsed)
|
||||
assert(x_str == parsed_str)
|
||||
assert x_str == parsed_str
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_UUID():
|
||||
u = UUID.from_16_bits(0x7788)
|
||||
assert(str(u) == 'UUID-16:7788')
|
||||
assert str(u) == 'UUID-16:7788'
|
||||
u = UUID.from_32_bits(0x11223344)
|
||||
assert(str(u) == 'UUID-32:11223344')
|
||||
assert str(u) == 'UUID-32:11223344'
|
||||
u = UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
||||
assert(str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
||||
assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||
v = UUID(str(u))
|
||||
assert(str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
||||
assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||
w = UUID.from_bytes(v.to_bytes())
|
||||
assert(str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
|
||||
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
|
||||
|
||||
u1 = UUID.from_16_bits(0x1234)
|
||||
b1 = u1.to_bytes(force_128 = True)
|
||||
u2 = UUID.from_bytes(b1)
|
||||
assert(u1 == u2)
|
||||
assert u1 == u2
|
||||
|
||||
u3 = UUID.from_16_bits(0x180a)
|
||||
assert(str(u3) == 'UUID-16:180A (Device Information)')
|
||||
assert str(u3) == 'UUID-16:180A (Device Information)'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -98,6 +99,122 @@ def test_ATT_Read_By_Group_Type_Request():
|
||||
basic_check(pdu)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_characteristic_encoding():
|
||||
class Foo(Characteristic):
|
||||
def encode_value(self, value):
|
||||
return bytes([value])
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
return value_bytes[0]
|
||||
|
||||
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123)
|
||||
x = c.read_value(None)
|
||||
assert x == bytes([123])
|
||||
c.write_value(None, bytes([122]))
|
||||
assert c.value == 122
|
||||
|
||||
class FooProxy(CharacteristicProxy):
|
||||
def __init__(self, characteristic):
|
||||
super().__init__(
|
||||
characteristic.client,
|
||||
characteristic.handle,
|
||||
characteristic.end_group_handle,
|
||||
characteristic.uuid,
|
||||
characteristic.properties
|
||||
)
|
||||
|
||||
def encode_value(self, value):
|
||||
return bytes([value])
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
return value_bytes[0]
|
||||
|
||||
[client, server] = TwoDevices().devices
|
||||
|
||||
characteristic = Characteristic(
|
||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE | Characteristic.WRITEABLE,
|
||||
bytes([123])
|
||||
)
|
||||
|
||||
service = Service(
|
||||
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
|
||||
[characteristic]
|
||||
)
|
||||
server.add_service(service)
|
||||
|
||||
await client.power_on()
|
||||
await server.power_on()
|
||||
connection = await client.connect(server.random_address)
|
||||
peer = Peer(connection)
|
||||
|
||||
await peer.discover_services()
|
||||
await peer.discover_characteristics()
|
||||
c = peer.get_characteristics_by_uuid(characteristic.uuid)
|
||||
assert len(c) == 1
|
||||
c = c[0]
|
||||
cp = FooProxy(c)
|
||||
|
||||
v = await cp.read_value()
|
||||
assert v == 123
|
||||
await cp.write_value(124)
|
||||
await async_barrier()
|
||||
assert characteristic.value == bytes([124])
|
||||
|
||||
last_change = None
|
||||
|
||||
def on_change(value):
|
||||
nonlocal last_change
|
||||
last_change = value
|
||||
|
||||
await c.subscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change == characteristic.value
|
||||
last_change = None
|
||||
|
||||
await server.notify_subscribers(characteristic, value=bytes([125]))
|
||||
await async_barrier()
|
||||
assert last_change == bytes([125])
|
||||
last_change = None
|
||||
|
||||
await c.unsubscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change is None
|
||||
|
||||
await cp.subscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change == characteristic.value[0]
|
||||
last_change = None
|
||||
|
||||
await server.notify_subscribers(characteristic, value=bytes([126]))
|
||||
await async_barrier()
|
||||
assert last_change == 126
|
||||
last_change = None
|
||||
|
||||
await cp.unsubscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change is None
|
||||
|
||||
cd = DelegatedCharacteristicAdapter(c, decode=lambda x: x[0])
|
||||
await cd.subscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change == characteristic.value[0]
|
||||
last_change = None
|
||||
|
||||
await cd.unsubscribe(on_change)
|
||||
await server.notify_subscribers(characteristic)
|
||||
await async_barrier()
|
||||
assert last_change is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_CharacteristicAdapter():
|
||||
# Check that the CharacteristicAdapter base class is transparent
|
||||
@@ -106,21 +223,21 @@ def test_CharacteristicAdapter():
|
||||
a = CharacteristicAdapter(c)
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == v)
|
||||
assert value == v
|
||||
|
||||
v = bytes([3, 4, 5])
|
||||
a.write_value(None, v)
|
||||
assert(c.value == v)
|
||||
assert c.value == v
|
||||
|
||||
# Simple delegated adapter
|
||||
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == bytes(reversed(v)))
|
||||
assert value == bytes(reversed(v))
|
||||
|
||||
v = bytes([3, 4, 5])
|
||||
a.write_value(None, v)
|
||||
assert(a.value == bytes(reversed(v)))
|
||||
assert a.value == bytes(reversed(v))
|
||||
|
||||
# Packed adapter with single element format
|
||||
v = 1234
|
||||
@@ -129,10 +246,10 @@ def test_CharacteristicAdapter():
|
||||
a = PackedCharacteristicAdapter(c, '>H')
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == pv)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
assert(a.value == v)
|
||||
assert a.value == v
|
||||
|
||||
# Packed adapter with multi-element format
|
||||
v1 = 1234
|
||||
@@ -142,10 +259,10 @@ def test_CharacteristicAdapter():
|
||||
a = PackedCharacteristicAdapter(c, '>HH')
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == pv)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
assert(a.value == (v1, v2))
|
||||
assert a.value == (v1, v2)
|
||||
|
||||
# Mapped adapter
|
||||
v1 = 1234
|
||||
@@ -156,10 +273,10 @@ def test_CharacteristicAdapter():
|
||||
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == pv)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
assert(a.value == mapped)
|
||||
assert a.value == mapped
|
||||
|
||||
# UTF-8 adapter
|
||||
v = 'Hello π'
|
||||
@@ -168,10 +285,10 @@ def test_CharacteristicAdapter():
|
||||
a = UTF8CharacteristicAdapter(c)
|
||||
|
||||
value = a.read_value(None)
|
||||
assert(value == ev)
|
||||
assert value == ev
|
||||
c.value = None
|
||||
a.write_value(None, ev)
|
||||
assert(a.value == v)
|
||||
assert a.value == v
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -179,13 +296,13 @@ def test_CharacteristicValue():
|
||||
b = bytes([1, 2, 3])
|
||||
c = CharacteristicValue(read=lambda _: b)
|
||||
x = c.read(None)
|
||||
assert(x == b)
|
||||
assert x == b
|
||||
|
||||
result = []
|
||||
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
|
||||
z = object()
|
||||
c.write(z, b)
|
||||
assert(result == [(z, b)])
|
||||
assert result == [(z, b)]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -265,35 +382,35 @@ async def test_read_write():
|
||||
await peer.discover_services()
|
||||
await peer.discover_characteristics()
|
||||
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c1 = c[0]
|
||||
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c2 = c[0]
|
||||
|
||||
v1 = await peer.read_value(c1)
|
||||
assert(v1 == b'')
|
||||
assert v1 == b''
|
||||
b = bytes([1, 2, 3])
|
||||
await peer.write_value(c1, b)
|
||||
await async_barrier()
|
||||
assert(characteristic1.value == b)
|
||||
assert characteristic1.value == b
|
||||
v1 = await peer.read_value(c1)
|
||||
assert(v1 == b)
|
||||
assert(type(characteristic1._last_value) is tuple)
|
||||
assert(len(characteristic1._last_value) == 2)
|
||||
assert(str(characteristic1._last_value[0].peer_address) == str(client.random_address))
|
||||
assert(characteristic1._last_value[1] == b)
|
||||
assert v1 == b
|
||||
assert type(characteristic1._last_value is tuple)
|
||||
assert len(characteristic1._last_value) == 2
|
||||
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address)
|
||||
assert characteristic1._last_value[1] == b
|
||||
bb = bytes([3, 4, 5, 6])
|
||||
characteristic1.value = bb
|
||||
v1 = await peer.read_value(c1)
|
||||
assert(v1 == bb)
|
||||
assert v1 == bb
|
||||
|
||||
await peer.write_value(c2, b)
|
||||
await async_barrier()
|
||||
assert(type(characteristic2._last_value) is tuple)
|
||||
assert(len(characteristic2._last_value) == 2)
|
||||
assert(str(characteristic2._last_value[0].peer_address) == str(client.random_address))
|
||||
assert(characteristic2._last_value[1] == b)
|
||||
assert type(characteristic2._last_value is tuple)
|
||||
assert len(characteristic2._last_value) == 2
|
||||
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address)
|
||||
assert characteristic2._last_value[1] == b
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -324,26 +441,26 @@ async def test_read_write2():
|
||||
|
||||
await peer.discover_services()
|
||||
c = peer.get_services_by_uuid(service1.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
s = c[0]
|
||||
await s.discover_characteristics()
|
||||
c = s.get_characteristics_by_uuid(characteristic1.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c1 = c[0]
|
||||
|
||||
v1 = await c1.read_value()
|
||||
assert(v1 == v)
|
||||
assert v1 == v
|
||||
|
||||
a1 = PackedCharacteristicAdapter(c1, '>I')
|
||||
v1 = await a1.read_value()
|
||||
assert(v1 == struct.unpack('>I', v)[0])
|
||||
assert v1 == struct.unpack('>I', v)[0]
|
||||
|
||||
b = bytes([0x55, 0x66, 0x77, 0x88])
|
||||
await a1.write_value(struct.unpack('>I', b)[0])
|
||||
await async_barrier()
|
||||
assert(characteristic1.value == b)
|
||||
assert characteristic1.value == b
|
||||
v1 = await a1.read_value()
|
||||
assert(v1 == struct.unpack('>I', b)[0])
|
||||
assert v1 == struct.unpack('>I', b)[0]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -410,13 +527,13 @@ async def test_subscribe_notify():
|
||||
await peer.discover_services()
|
||||
await peer.discover_characteristics()
|
||||
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c1 = c[0]
|
||||
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c2 = c[0]
|
||||
c = peer.get_characteristics_by_uuid(characteristic3.uuid)
|
||||
assert(len(c) == 1)
|
||||
assert len(c) == 1
|
||||
c3 = c[0]
|
||||
|
||||
c1._called = False
|
||||
@@ -429,23 +546,32 @@ async def test_subscribe_notify():
|
||||
c1.on('update', on_c1_update)
|
||||
await peer.subscribe(c1)
|
||||
await async_barrier()
|
||||
assert(server._last_subscription[1] == characteristic1)
|
||||
assert(server._last_subscription[2])
|
||||
assert(not server._last_subscription[3])
|
||||
assert(characteristic1._last_subscription[1])
|
||||
assert(not characteristic1._last_subscription[2])
|
||||
assert server._last_subscription[1] == characteristic1
|
||||
assert server._last_subscription[2]
|
||||
assert not server._last_subscription[3]
|
||||
assert characteristic1._last_subscription[1]
|
||||
assert not characteristic1._last_subscription[2]
|
||||
await server.indicate_subscribers(characteristic1)
|
||||
await async_barrier()
|
||||
assert(not c1._called)
|
||||
assert not c1._called
|
||||
await server.notify_subscribers(characteristic1)
|
||||
await async_barrier()
|
||||
assert(c1._called)
|
||||
assert(c1._last_update == characteristic1.value)
|
||||
assert c1._called
|
||||
assert c1._last_update == characteristic1.value
|
||||
|
||||
c1._called = False
|
||||
c1._last_update = None
|
||||
c1_value = characteristic1.value
|
||||
await server.notify_subscribers(characteristic1, bytes([0, 1, 2]))
|
||||
await async_barrier()
|
||||
assert c1._called
|
||||
assert c1._last_update == bytes([0, 1, 2])
|
||||
assert characteristic1.value == c1_value
|
||||
|
||||
c1._called = False
|
||||
await peer.unsubscribe(c1)
|
||||
await server.notify_subscribers(characteristic1)
|
||||
assert(not c1._called)
|
||||
assert not c1._called
|
||||
|
||||
c2._called = False
|
||||
c2._last_update = None
|
||||
@@ -458,17 +584,17 @@ async def test_subscribe_notify():
|
||||
await async_barrier()
|
||||
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||
await async_barrier()
|
||||
assert(not c2._called)
|
||||
assert not c2._called
|
||||
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||
await async_barrier()
|
||||
assert(c2._called)
|
||||
assert(c2._last_update == characteristic2.value)
|
||||
assert c2._called
|
||||
assert c2._last_update == characteristic2.value
|
||||
|
||||
c2._called = False
|
||||
await peer.unsubscribe(c2, on_c2_update)
|
||||
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
|
||||
await async_barrier()
|
||||
assert(not c2._called)
|
||||
assert not c2._called
|
||||
|
||||
def on_c3_update(value):
|
||||
c3._called = True
|
||||
@@ -483,17 +609,17 @@ async def test_subscribe_notify():
|
||||
await async_barrier()
|
||||
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await async_barrier()
|
||||
assert(c3._called)
|
||||
assert(c3._last_update == characteristic3.value)
|
||||
assert(c3._called_2)
|
||||
assert(c3._last_update_2 == characteristic3.value)
|
||||
assert c3._called
|
||||
assert c3._last_update == characteristic3.value
|
||||
assert c3._called_2
|
||||
assert c3._last_update_2 == characteristic3.value
|
||||
characteristic3.value = bytes([1, 2, 3])
|
||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await async_barrier()
|
||||
assert(c3._called)
|
||||
assert(c3._last_update == characteristic3.value)
|
||||
assert(c3._called_2)
|
||||
assert(c3._last_update_2 == characteristic3.value)
|
||||
assert c3._called
|
||||
assert c3._last_update == characteristic3.value
|
||||
assert c3._called_2
|
||||
assert c3._last_update_2 == characteristic3.value
|
||||
|
||||
c3._called = False
|
||||
c3._called_2 = False
|
||||
@@ -501,8 +627,8 @@ async def test_subscribe_notify():
|
||||
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await async_barrier()
|
||||
assert(not c3._called)
|
||||
assert(not c3._called_2)
|
||||
assert not c3._called
|
||||
assert not c3._called_2
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -510,6 +636,8 @@ async def async_main():
|
||||
await test_read_write()
|
||||
await test_read_write2()
|
||||
await test_subscribe_notify()
|
||||
await test_characteristic_encoding()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -246,8 +246,7 @@ IO_CAP = [
|
||||
SC = [False, True]
|
||||
MITM = [False, True]
|
||||
# Key distribution is a 4-bit bitmask
|
||||
# IdKey is necessary for current SMP structure
|
||||
KEY_DIST = [i for i in range(16) if (i & SMP_ID_KEY_DISTRIBUTION_FLAG)]
|
||||
KEY_DIST = range(16)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('io_cap, sc, mitm, key_dist',
|
||||
|
||||
Reference in New Issue
Block a user