Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod a9e726545e wip 2022-09-20 14:58:20 -07:00
12 changed files with 315 additions and 525 deletions
+4 -18
View File
@@ -700,26 +700,16 @@ class Attribute(EventEmitter):
else:
self.value = value
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
def read_value(self, connection):
if read := getattr(self.value, 'read', None):
try:
value = read(connection)
return read(connection)
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
else:
value = self.value
return self.encode_value(value)
def write_value(self, connection, value_bytes):
value = self.decode_value(value_bytes)
return self.value
def write_value(self, connection, value):
if write := getattr(self.value, 'write', None):
try:
write(connection, value)
@@ -731,11 +721,7 @@ class Attribute(EventEmitter):
self.emit('write', connection, value)
def __repr__(self):
if type(self.value) is bytes:
value_str = self.value.hex()
else:
value_str = str(self.value)
if value_str:
if len(self.value) > 0:
value_string = f', value={self.value.hex()}'
else:
value_string = ''
+117 -28
View File
@@ -18,7 +18,8 @@
import json
import asyncio
import logging
from contextlib import asynccontextmanager, AsyncExitStack
import secrets
from contextlib import asynccontextmanager, AsyncExitStack
from .hci import *
from .host import Host
@@ -32,6 +33,8 @@ from . import smp
from . import sdp
from . import l2cap
from . import keys
from . import crypto
# -----------------------------------------------------------------------------
# Logging
@@ -51,6 +54,7 @@ DEVICE_DEFAULT_SCAN_RESPONSE_DATA = b''
DEVICE_DEFAULT_DATA_LENGTH = (27, 328, 27, 328)
DEVICE_DEFAULT_SCAN_INTERVAL = 60 # ms
DEVICE_DEFAULT_SCAN_WINDOW = 60 # ms
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
@@ -169,7 +173,6 @@ class Peer:
async def __aexit__(self, exc_type, exc_value, traceback):
pass
def __str__(self):
return f'{self.connection.peer_address} as {self.connection.role_name}'
@@ -202,11 +205,22 @@ class Connection(CompositeEventEmitter):
def on_connection_encryption_key_refresh(self):
pass
def __init__(self, device, handle, transport, peer_address, peer_resolvable_address, role, parameters):
def __init__(
self,
device,
handle,
transport,
local_address,
peer_address,
peer_resolvable_address,
role,
parameters
):
super().__init__()
self.device = device
self.handle = handle
self.transport = transport
self.local_address = local_address
self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only
@@ -297,7 +311,12 @@ class Connection(CompositeEventEmitter):
raise
def __str__(self):
return f'Connection(handle=0x{self.handle:04X}, role={self.role_name}, address={self.peer_address})'
return (
f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'local_address={self.local_address}, '
f'peer_address={self.peer_address})'
)
# -----------------------------------------------------------------------------
@@ -311,8 +330,10 @@ class DeviceConfiguration:
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.le_enabled = True
# LE host enable 2nd parameter
self.le_simultaneous_enabled = True
self.le_privacy_enabled = False
self.le_rpa_timeout = DEVICE_DEFAULT_LE_RPA_TIMEOUT
self.classic_enabled = False
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
self.connectable = True
@@ -320,19 +341,22 @@ class DeviceConfiguration:
self.advertising_data = bytes(
AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))])
)
self.irk = bytes(16) # This really must be changed for any level of security
self.irk = bytes([0xFF] * 16) # This really must be changed for any level of security
self.keystore = None
def load_from_dict(self, config):
# Load simple properties
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get('advertising_interval', self.advertising_interval_min)
self.advertising_interval_max = self.advertising_interval_min
self.keystore = config.get('keystore')
self.le_enabled = config.get('le_enabled', self.le_enabled)
self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled)
self.le_privacy_enabled = config.get('le_privacy_enabled', self.le_privacy_enabled)
self.le_rpa_timeout = config.get('le_rpa_timeout', self.le_rpa_timeout)
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled)
self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled)
self.connectable = config.get('connectable', self.connectable)
@@ -352,6 +376,10 @@ class DeviceConfiguration:
advertising_data = config.get('advertising_data')
if advertising_data:
self.advertising_data = bytes.fromhex(advertising_data)
else:
self.advertising_data = bytes(
AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))])
)
def load_from_file(self, filename):
with open(filename, 'r') as file:
@@ -458,9 +486,9 @@ class Device(CompositeEventEmitter):
self.connecting = False
self.disconnecting = False
self.connections = {} # Connections, by connection handle
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.le_rpa_task = None
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
@@ -468,6 +496,7 @@ class Device(CompositeEventEmitter):
config = DeviceConfiguration()
self.name = config.name
self.random_address = config.address
self.identity_address = config.address
self.class_of_device = config.class_of_device
self.scan_response_data = config.scan_response_data
self.advertising_data = config.advertising_data
@@ -477,6 +506,9 @@ class Device(CompositeEventEmitter):
self.irk = config.irk
self.le_enabled = config.le_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.classic_enabled = config.classic_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_sc_enabled = config.classic_sc_enabled
self.discoverable = config.discoverable
@@ -490,11 +522,12 @@ class Device(CompositeEventEmitter):
if address:
if type(address) is str:
address = Address(address)
self.random_address = address
self.random_address = address
self.identity_address = address
# Setup SMP
# TODO: allow using a public address
self.smp_manager = smp.Manager(self, self.random_address)
self.smp_manager = smp.Manager(self, self.random_address, self.identity_address)
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_CID, self.on_smp_pdu)
self.l2cap_channel_manager.register_fixed_channel(
@@ -591,6 +624,14 @@ class Device(CompositeEventEmitter):
))
if self.le_enabled:
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = self.generate_le_rpa()
logger.info(f'Initial RPA: {self.random_address}')
if self.le_rpa_timeout > 0:
# Start a task to periodically generate a new RPA
self.le_rpa_task = asyncio.create_task(self.run_le_rpa_generation())
# Set the controller address
await self.send_command(HCI_LE_Set_Random_Address_Command(
random_address = self.random_address
@@ -637,13 +678,48 @@ class Device(CompositeEventEmitter):
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
# Let the SMP manager know about the address
# TODO: allow using a public address
self.smp_manager.address = self.random_address
# Done
self.powered_on = True
async def run_le_rpa_generation(self):
while self.le_rpa_timeout != 0:
await asyncio.sleep(self.le_rpa_timeout)
# Check if this is a good time to rotate the address
if self.advertising or self.scanning or self.connecting:
logger.debug('skipping RPA rotation')
continue
random_address = self.generate_le_rpa()
response = await self.send_command(HCI_LE_Set_Random_Address_Command(
random_address = self.random_address
))
if response.return_parameters == HCI_SUCCESS:
logger.info(f'New RPA: {random_address}')
self.random_address = random_address
else:
logger.warning(f'failed to set RPA: {response.return_parameters}')
def generate_le_rpa(self):
# See 1.3.2.2 Private device address generation
# Generate `prand`
while True:
# Generate a 22-bit random number for the random part of `prand`
prand_random = secrets.randbelow(0x400000)
# As least on bit shall be 0 and one bit shall be 1
if prand_random != 0 and prand_random != 0x3FFFFF:
break
prand = prand_random | 0x400000 # The two MSBs are |1|0|
# Generate `hash`
hash = crypto.ah(self.irk, struct.pack('<I', prand)[:3])
# Generate the address from `prand` and `hash`
return Address(hash + struct.pack('<I', prand)[:3], Address.RANDOM_IDENTITY_ADDRESS)
async def start_advertising(self, auto_restart=False):
self.auto_restart_advertising = auto_restart
@@ -675,18 +751,24 @@ class Device(CompositeEventEmitter):
))
# Enable advertising
await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
response = await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
advertising_enable = 1
))
if response.return_parameters != HCI_SUCCESS:
logger.warning(f'HCI_LE_Set_Advertising_Enable_Command failed ({response.return_parameters})')
raise HCI_Error(response.return_parameters)
self.advertising = True
async def stop_advertising(self):
# Disable advertising
if self.advertising:
await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
response = await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
advertising_enable = 0
))
if response.return_parameters != HCI_SUCCESS:
logger.warning(f'HCI_LE_Set_Advertising_Enable_Command failed ({response.return_parameters})')
raise HCI_Error(response.return_parameters)
self.advertising = False
@@ -721,17 +803,23 @@ class Device(CompositeEventEmitter):
))
# Enable scanning
await self.send_command(HCI_LE_Set_Scan_Enable_Command(
response = await self.send_command(HCI_LE_Set_Scan_Enable_Command(
le_scan_enable = 1,
filter_duplicates = 1 if filter_duplicates else 0
))
if response.return_parameters != HCI_SUCCESS:
raise HCI_Error(response.return_parameters)
self.scanning = True
async def stop_scanning(self):
await self.send_command(HCI_LE_Set_Scan_Enable_Command(
response = await self.send_command(HCI_LE_Set_Scan_Enable_Command(
le_scan_enable = 0,
filter_duplicates = 0
))
if response.return_parameters != HCI_SUCCESS:
raise HCI_Error(response.return_parameters)
self.scanning = False
@property
@@ -1210,17 +1298,17 @@ class Device(CompositeEventEmitter):
def add_services(self, services):
self.gatt_server.add_services(services)
async def notify_subscriber(self, connection, attribute, value=None, force=False):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscriber(self, connection, attribute, force=False):
await self.gatt_server.notify_subscriber(connection, attribute, force)
async def notify_subscribers(self, attribute, value=None, force=False):
await self.gatt_server.notify_subscribers(attribute, value, force)
async def notify_subscribers(self, attribute, force=False):
await self.gatt_server.notify_subscribers(attribute, force)
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscriber(self, connection, attribute, force=False):
await self.gatt_server.indicate_subscriber(connection, attribute, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
await self.gatt_server.indicate_subscribers(attribute, value, force)
async def indicate_subscribers(self, attribute):
await self.gatt_server.indicate_subscribers(attribute)
@host_event_handler
def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters):
@@ -1242,6 +1330,7 @@ class Device(CompositeEventEmitter):
self,
connection_handle,
transport,
self.public_address if transport == BT_BR_EDR_TRANSPORT else self.random_address,
peer_address,
peer_resolvable_address,
role,
+4 -32
View File
@@ -303,7 +303,6 @@ class CharacteristicAdapter:
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if (
asyncio.iscoroutinefunction(characteristic.read_value) and
@@ -318,21 +317,11 @@ class CharacteristicAdapter:
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in {
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe'
}:
if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
@@ -356,26 +345,9 @@ class CharacteristicAdapter:
return value
def wrapped_subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
return self.wrapped_characteristic.subscribe(
None if subscriber is None else lambda value: subscriber(self.decode_value(value))
)
def __str__(self):
wrapped = str(self.wrapped_characteristic)
+9 -39
View File
@@ -58,16 +58,10 @@ class AttributeProxy(EventEmitter):
self.type = attribute_type
async def read_value(self, no_long_read=False):
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
return await self.client.read_value(self.handle, no_long_read)
async def write_value(self, value, with_response=False):
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
return await self.client.write_value(self.handle, value, with_response)
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
@@ -104,7 +98,6 @@ class CharacteristicProxy(AttributeProxy):
self.properties = properties
self.descriptors = []
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
@@ -115,25 +108,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return await self.client.subscribe(self, subscriber)
async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
@@ -163,6 +140,7 @@ class ProfileServiceProxy:
class Client:
def __init__(self, connection):
self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -216,7 +194,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return
# Send the request
self.mtu_exchange_done = True
@@ -229,10 +207,8 @@ class Client:
response
)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
return self.mtu
def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
@@ -594,18 +570,12 @@ class Client:
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
subscriber_set.remove(characteristic.handle)
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False):
'''
@@ -630,7 +600,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -652,7 +622,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
+59 -47
View File
@@ -40,12 +40,6 @@ from .gatt import *
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -55,8 +49,9 @@ class Server(EventEmitter):
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
@@ -174,7 +169,7 @@ class Server(EventEmitter):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -189,12 +184,13 @@ class Server(EventEmitter):
logger.debug(f'not notifying, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Get the value
value = attribute.read_value(connection)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
# Notify
notification = ATT_Handle_Value_Notification(
@@ -202,9 +198,27 @@ class Server(EventEmitter):
attribute_value = value
)
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
self.send_gatt_pdu(connection.handle, notification.to_bytes())
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscribers(self, attribute, force=False):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Notify for each connection
if connections:
await asyncio.wait([
self.notify_subscriber(connection, attribute, force)
for connection in connections
])
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -219,12 +233,13 @@ class Server(EventEmitter):
logger.debug(f'not indicating, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Get the value
value = attribute.read_value(connection)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
# Indicate
indication = ATT_Handle_Value_Indication(
@@ -249,32 +264,27 @@ class Server(EventEmitter):
finally:
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
async def indicate_subscribers(self, attribute):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
if subscribers.get(attribute.handle)
]
if connection is not None
]
# Indicate or notify for each connection
# Indicate for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait([
asyncio.create_task(coroutine(connection, attribute, value, force))
self.indicate_subscriber(connection, attribute)
for connection in connections
])
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -315,6 +325,9 @@ class Server(EventEmitter):
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
#######################################################
# ATT handlers
#######################################################
@@ -334,16 +347,12 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
def on_att_find_information_request(self, connection, request):
'''
@@ -360,7 +369,7 @@ class Server(EventEmitter):
return
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = self.get_mtu(connection) - 2
attributes = []
uuid_size = 0
for attribute in (
@@ -411,7 +420,7 @@ class Server(EventEmitter):
'''
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = self.get_mtu(connection) - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -459,7 +468,8 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = connection.att_mtu - 2
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -472,7 +482,7 @@ class Server(EventEmitter):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(connection.att_mtu - 4, 253)
max_attribute_size = min(mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -512,7 +522,7 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value))
value_size = min(self.get_mtu(connection) - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
@@ -531,6 +541,7 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
@@ -538,14 +549,14 @@ class Server(EventEmitter):
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
)
elif len(value) <= connection.att_mtu - 1:
elif len(value) <= mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
)
else:
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
part_size = min(mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
)
@@ -574,7 +585,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
return
pdu_space_available = connection.att_mtu - 2
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -585,7 +597,7 @@ class Server(EventEmitter):
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(connection.att_mtu - 6, 251)
max_attribute_size = min(mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
+5 -3
View File
@@ -1375,9 +1375,11 @@ class HCI_Error(ProtocolError):
class HCI_StatusError(ProtocolError):
def __init__(self, response):
super().__init__(response.status,
error_namespace=HCI_Command.command_name(response.command_opcode),
error_name=HCI_Constant.status_name(response.status))
super().__init__(
response.status,
error_namespace=HCI_Command.command_name(response.command_opcode),
error_name=HCI_Constant.status_name(response.status)
)
# -----------------------------------------------------------------------------
-5
View File
@@ -18,8 +18,6 @@
import logging
from colors import color
from bumble.smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
@@ -75,9 +73,6 @@ class PacketTracer:
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
+16 -23
View File
@@ -155,7 +155,6 @@ SMP_CT2_AUTHREQ = 0b00100000
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -639,13 +638,13 @@ class Session:
# Set up addresses
peer_address = connection.peer_resolvable_address or connection.peer_address
if self.is_initiator:
self.ia = bytes(manager.address)
self.iat = 1 if manager.address.is_random else 0
self.ia = bytes(connection.local_address)
self.iat = 1 if connection.local_address.is_random else 0
self.ra = bytes(peer_address)
self.rat = 1 if peer_address.is_random else 0
else:
self.ra = bytes(manager.address)
self.rat = 1 if manager.address.is_random else 0
self.ra = bytes(connection.local_address)
self.rat = 1 if connection.local_address.is_random else 0
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
@@ -907,8 +906,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
addr_type = self.manager.identity_address.address_type,
bd_addr = self.manager.identity_address
))
# Distribute CSRK
@@ -939,8 +938,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
addr_type = self.manager.identity_address.address_type,
bd_addr = self.manager.identity_address
))
# Distribute CSRK
@@ -981,7 +980,12 @@ class Session:
self.peer_expected_distributions.remove(command_class)
logger.debug(f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}')
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
# Nothing left to expect, we're done
asyncio.create_task(self.on_pairing())
else:
logger.warn(color(f'!!! unexpected key distribution command: {command_class.__name__}', 'red'))
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
@@ -1002,23 +1006,12 @@ class Session:
self.connection.remove_listener('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh)
self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self):
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
asyncio.create_task(self.on_pairing())
def on_connection_encryption_change(self):
if self.connection.is_encrypted:
if self.is_responder:
# The responder distributes its keys first, the initiator later
self.distribute_keys()
# If we're not expecting key distributions from the peer, we're done
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self):
# Do as if the connection had just been encrypted
self.on_connection_encryption_change()
@@ -1486,10 +1479,10 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol
'''
def __init__(self, device, address):
def __init__(self, device, address, identity_address):
super().__init__()
self.device = device
self.address = address
self.identity_address = identity_address
self.sessions = {}
self._ecc_key = None
self.pairing_config_factory = lambda connection: PairingConfig()
+22 -82
View File
@@ -56,19 +56,18 @@ async def open_usb_transport(spec):
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_ACL_OUT = 0x02
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device, acl_out):
def __init__(self, device):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
@@ -117,7 +116,7 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
self.acl_out,
USB_ENDPOINT_ACL_OUT,
packet[1:],
callback=self.on_packet_sent
)
@@ -153,12 +152,10 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device, acl_in, events_in):
def __init__(self, context, device):
super().__init__()
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
@@ -175,7 +172,7 @@ async def open_usb_transport(spec):
# Set up transfer objects for input
self.events_in_transfer = device.getTransfer()
self.events_in_transfer.setInterrupt(
self.events_in,
USB_ENDPOINT_EVENTS_IN,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET
@@ -184,7 +181,7 @@ async def open_usb_transport(spec):
self.acl_in_transfer = device.getTransfer()
self.acl_in_transfer.setBulk(
self.acl_in,
USB_ENDPOINT_ACL_IN,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET
@@ -251,7 +248,7 @@ async def open_usb_transport(spec):
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink):
def __init__(self, context, device, interface, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
@@ -260,10 +257,6 @@ async def open_usb_transport(spec):
# Get exclusive access
device.claimInterface(interface)
# Set the alternate setting if not the default
if setting != 0:
device.setInterfaceAltSetting(interface, setting)
# The source and sink can now start
source.start()
sink.start()
@@ -320,64 +313,11 @@ async def open_usb_transport(spec):
raise ValueError('device not found')
logger.debug(f'USB Device: {found}')
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
setting.getClass() != USB_DEVICE_CLASS_WIRELESS_CONTROLLER or
setting.getSubClass() != USB_DEVICE_SUBCLASS_RF_CONTROLLER or
setting.getProtocol() != USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
):
continue
events_in = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None:
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in
)
else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}')
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
f'interface={interface}, '
f'setting={setting}, '
f'acl_in=0x{acl_in:02X}, '
f'acl_out=0x{acl_out:02X}, '
f'events_in=0x{events_in:02X}, '
)
device = found.open()
# Use the first interface
interface = 0
# Detach the kernel driver if supported and needed
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
@@ -389,21 +329,21 @@ async def open_usb_transport(spec):
# Set the configuration if needed
try:
current_configuration = device.getConfiguration()
logger.debug(f'current configuration = {current_configuration}')
configuration = device.getConfiguration()
logger.debug(f'current configuration = {configuration}')
except usb1.USBError:
current_configuration = 0
configuration = 0
if current_configuration != configuration:
if configuration != 1:
try:
logger.debug(f'setting configuration {configuration}')
device.setConfiguration(configuration)
logger.debug('setting configuration 1')
device.setConfiguration(1)
except usb1.USBError:
logger.warning('failed to set configuration')
logger.warning('failed to set configuration 1')
source = UsbPacketSource(context, device, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
source = UsbPacketSource(context, device)
sink = UsbPacketSink(device)
return UsbTransport(context, device, interface, source, sink)
except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
context.close()
+1 -1
View File
@@ -20,7 +20,7 @@ import sys
import os
import logging
from colors import color
from bumble.device import Device, Peer
from bumble.device import Device
from bumble.transport import open_transport
from bumble.profiles.battery_service import BatteryServiceProxy
+76 -246
View File
@@ -22,7 +22,6 @@ import struct
import pytest
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -54,29 +53,29 @@ def basic_check(x):
parsed = ATT_PDU.from_bytes(pdu)
x_str = str(x)
parsed_str = str(parsed)
assert x_str == parsed_str
assert(x_str == parsed_str)
# -----------------------------------------------------------------------------
def test_UUID():
u = UUID.from_16_bits(0x7788)
assert str(u) == 'UUID-16:7788'
assert(str(u) == 'UUID-16:7788')
u = UUID.from_32_bits(0x11223344)
assert str(u) == 'UUID-32:11223344'
assert(str(u) == 'UUID-32:11223344')
u = UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
assert(str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
v = UUID(str(u))
assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
assert(str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
w = UUID.from_bytes(v.to_bytes())
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
assert(str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
u1 = UUID.from_16_bits(0x1234)
b1 = u1.to_bytes(force_128 = True)
u2 = UUID.from_bytes(b1)
assert u1 == u2
assert(u1 == u2)
u3 = UUID.from_16_bits(0x180a)
assert str(u3) == 'UUID-16:180A (Device Information)'
assert(str(u3) == 'UUID-16:180A (Device Information)')
# -----------------------------------------------------------------------------
@@ -99,122 +98,6 @@ def test_ATT_Read_By_Group_Type_Request():
basic_check(pdu)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_characteristic_encoding():
class Foo(Characteristic):
def encode_value(self, value):
return bytes([value])
def decode_value(self, value_bytes):
return value_bytes[0]
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123)
x = c.read_value(None)
assert x == bytes([123])
c.write_value(None, bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
def __init__(self, characteristic):
super().__init__(
characteristic.client,
characteristic.handle,
characteristic.end_group_handle,
characteristic.uuid,
characteristic.properties
)
def encode_value(self, value):
return bytes([value])
def decode_value(self, value_bytes):
return value_bytes[0]
[client, server] = LinkedDevices().devices[:2]
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123])
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic]
)
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic.uuid)
assert len(c) == 1
c = c[0]
cp = FooProxy(c)
v = await cp.read_value()
assert v == 123
await cp.write_value(124)
await async_barrier()
assert characteristic.value == bytes([124])
last_change = None
def on_change(value):
nonlocal last_change
last_change = value
await c.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value
last_change = None
await server.notify_subscribers(characteristic, value=bytes([125]))
await async_barrier()
assert last_change == bytes([125])
last_change = None
await c.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
await cp.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value[0]
last_change = None
await server.notify_subscribers(characteristic, value=bytes([126]))
await async_barrier()
assert last_change == 126
last_change = None
await cp.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
cd = DelegatedCharacteristicAdapter(c, decode=lambda x: x[0])
await cd.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value[0]
last_change = None
await cd.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
# -----------------------------------------------------------------------------
def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
@@ -223,21 +106,21 @@ def test_CharacteristicAdapter():
a = CharacteristicAdapter(c)
value = a.read_value(None)
assert value == v
assert(value == v)
v = bytes([3, 4, 5])
a.write_value(None, v)
assert c.value == v
assert(c.value == v)
# Simple delegated adapter
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
value = a.read_value(None)
assert value == bytes(reversed(v))
assert(value == bytes(reversed(v)))
v = bytes([3, 4, 5])
a.write_value(None, v)
assert a.value == bytes(reversed(v))
assert(a.value == bytes(reversed(v)))
# Packed adapter with single element format
v = 1234
@@ -246,10 +129,10 @@ def test_CharacteristicAdapter():
a = PackedCharacteristicAdapter(c, '>H')
value = a.read_value(None)
assert value == pv
assert(value == pv)
c.value = None
a.write_value(None, pv)
assert a.value == v
assert(a.value == v)
# Packed adapter with multi-element format
v1 = 1234
@@ -259,10 +142,10 @@ def test_CharacteristicAdapter():
a = PackedCharacteristicAdapter(c, '>HH')
value = a.read_value(None)
assert value == pv
assert(value == pv)
c.value = None
a.write_value(None, pv)
assert a.value == (v1, v2)
assert(a.value == (v1, v2))
# Mapped adapter
v1 = 1234
@@ -273,10 +156,10 @@ def test_CharacteristicAdapter():
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = a.read_value(None)
assert value == pv
assert(value == pv)
c.value = None
a.write_value(None, pv)
assert a.value == mapped
assert(a.value == mapped)
# UTF-8 adapter
v = 'Hello π'
@@ -285,10 +168,10 @@ def test_CharacteristicAdapter():
a = UTF8CharacteristicAdapter(c)
value = a.read_value(None)
assert value == ev
assert(value == ev)
c.value = None
a.write_value(None, ev)
assert a.value == v
assert(a.value == v)
# -----------------------------------------------------------------------------
@@ -296,25 +179,24 @@ def test_CharacteristicValue():
b = bytes([1, 2, 3])
c = CharacteristicValue(read=lambda _: b)
x = c.read(None)
assert x == b
assert(x == b)
result = []
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
z = object()
c.write(z, b)
assert result == [(z, b)]
assert(result == [(z, b)])
# -----------------------------------------------------------------------------
class LinkedDevices:
class TwoDevices:
def __init__(self):
self.connections = [None, None, None]
self.connections = [None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link),
Controller('C3', link = self.link)
Controller('C2', link = self.link)
]
self.devices = [
Device(
@@ -322,16 +204,12 @@ class LinkedDevices:
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F1:F2:F3:F4:F5:F6',
address = 'F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
)
]
self.paired = [None, None, None]
self.paired = [None, None]
# -----------------------------------------------------------------------------
@@ -344,7 +222,7 @@ async def async_barrier():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write():
[client, server] = LinkedDevices().devices[:2]
[client, server] = TwoDevices().devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -387,41 +265,41 @@ async def test_read_write():
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
assert(len(c) == 1)
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert len(c) == 1
assert(len(c) == 1)
c2 = c[0]
v1 = await peer.read_value(c1)
assert v1 == b''
assert(v1 == b'')
b = bytes([1, 2, 3])
await peer.write_value(c1, b)
await async_barrier()
assert characteristic1.value == b
assert(characteristic1.value == b)
v1 = await peer.read_value(c1)
assert v1 == b
assert type(characteristic1._last_value is tuple)
assert len(characteristic1._last_value) == 2
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address)
assert characteristic1._last_value[1] == b
assert(v1 == b)
assert(type(characteristic1._last_value) is tuple)
assert(len(characteristic1._last_value) == 2)
assert(str(characteristic1._last_value[0].peer_address) == str(client.random_address))
assert(characteristic1._last_value[1] == b)
bb = bytes([3, 4, 5, 6])
characteristic1.value = bb
v1 = await peer.read_value(c1)
assert v1 == bb
assert(v1 == bb)
await peer.write_value(c2, b)
await async_barrier()
assert type(characteristic2._last_value is tuple)
assert len(characteristic2._last_value) == 2
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address)
assert characteristic2._last_value[1] == b
assert(type(characteristic2._last_value) is tuple)
assert(len(characteristic2._last_value) == 2)
assert(str(characteristic2._last_value[0].peer_address) == str(client.random_address))
assert(characteristic2._last_value[1] == b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write2():
[client, server] = LinkedDevices().devices[:2]
[client, server] = TwoDevices().devices
v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic(
@@ -446,32 +324,32 @@ async def test_read_write2():
await peer.discover_services()
c = peer.get_services_by_uuid(service1.uuid)
assert len(c) == 1
assert(len(c) == 1)
s = c[0]
await s.discover_characteristics()
c = s.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
assert(len(c) == 1)
c1 = c[0]
v1 = await c1.read_value()
assert v1 == v
assert(v1 == v)
a1 = PackedCharacteristicAdapter(c1, '>I')
v1 = await a1.read_value()
assert v1 == struct.unpack('>I', v)[0]
assert(v1 == struct.unpack('>I', v)[0])
b = bytes([0x55, 0x66, 0x77, 0x88])
await a1.write_value(struct.unpack('>I', b)[0])
await async_barrier()
assert characteristic1.value == b
assert(characteristic1.value == b)
v1 = await a1.read_value()
assert v1 == struct.unpack('>I', b)[0]
assert(v1 == struct.unpack('>I', b)[0])
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_notify():
[client, server] = LinkedDevices().devices[:2]
[client, server] = TwoDevices().devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -532,13 +410,13 @@ async def test_subscribe_notify():
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
assert(len(c) == 1)
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert len(c) == 1
assert(len(c) == 1)
c2 = c[0]
c = peer.get_characteristics_by_uuid(characteristic3.uuid)
assert len(c) == 1
assert(len(c) == 1)
c3 = c[0]
c1._called = False
@@ -551,32 +429,23 @@ async def test_subscribe_notify():
c1.on('update', on_c1_update)
await peer.subscribe(c1)
await async_barrier()
assert server._last_subscription[1] == characteristic1
assert server._last_subscription[2]
assert not server._last_subscription[3]
assert characteristic1._last_subscription[1]
assert not characteristic1._last_subscription[2]
assert(server._last_subscription[1] == characteristic1)
assert(server._last_subscription[2])
assert(not server._last_subscription[3])
assert(characteristic1._last_subscription[1])
assert(not characteristic1._last_subscription[2])
await server.indicate_subscribers(characteristic1)
await async_barrier()
assert not c1._called
assert(not c1._called)
await server.notify_subscribers(characteristic1)
await async_barrier()
assert c1._called
assert c1._last_update == characteristic1.value
c1._called = False
c1._last_update = None
c1_value = characteristic1.value
await server.notify_subscribers(characteristic1, bytes([0, 1, 2]))
await async_barrier()
assert c1._called
assert c1._last_update == bytes([0, 1, 2])
assert characteristic1.value == c1_value
assert(c1._called)
assert(c1._last_update == characteristic1.value)
c1._called = False
await peer.unsubscribe(c1)
await server.notify_subscribers(characteristic1)
assert not c1._called
assert(not c1._called)
c2._called = False
c2._last_update = None
@@ -589,17 +458,17 @@ async def test_subscribe_notify():
await async_barrier()
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert not c2._called
assert(not c2._called)
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert c2._called
assert c2._last_update == characteristic2.value
assert(c2._called)
assert(c2._last_update == characteristic2.value)
c2._called = False
await peer.unsubscribe(c2, on_c2_update)
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert not c2._called
assert(not c2._called)
def on_c3_update(value):
c3._called = True
@@ -614,17 +483,17 @@ async def test_subscribe_notify():
await async_barrier()
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert c3._called
assert c3._last_update == characteristic3.value
assert c3._called_2
assert c3._last_update_2 == characteristic3.value
assert(c3._called)
assert(c3._last_update == characteristic3.value)
assert(c3._called_2)
assert(c3._last_update_2 == characteristic3.value)
characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert c3._called
assert c3._last_update == characteristic3.value
assert c3._called_2
assert c3._last_update_2 == characteristic3.value
assert(c3._called)
assert(c3._last_update == characteristic3.value)
assert(c3._called_2)
assert(c3._last_update_2 == characteristic3.value)
c3._called = False
c3._called_2 = False
@@ -632,44 +501,8 @@ async def test_subscribe_notify():
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert not c3._called
assert not c3._called_2
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]
d3.gatt_server.max_mtu = 100
d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)
await d1.power_on()
await d2.power_on()
await d3.power_on()
d1_connection = await d1.connect(d3.random_address)
assert len(d3_connections) == 1
assert d3_connections[0] is not None
d2_connection = await d2.connect(d3.random_address)
assert len(d3_connections) == 2
assert d3_connections[1] is not None
d1_peer = Peer(d1_connection)
d2_peer = Peer(d2_connection)
d1_client_mtu = await d1_peer.request_mtu(220)
assert d1_client_mtu == 100
assert d1_connection.att_mtu == 100
d2_client_mtu = await d2_peer.request_mtu(50)
assert d2_client_mtu == 50
assert d2_connection.att_mtu == 50
assert(not c3._called)
assert(not c3._called_2)
# -----------------------------------------------------------------------------
@@ -677,9 +510,6 @@ async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_characteristic_encoding()
await test_mtu_exchange()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
+2 -1
View File
@@ -246,7 +246,8 @@ IO_CAP = [
SC = [False, True]
MITM = [False, True]
# Key distribution is a 4-bit bitmask
KEY_DIST = range(16)
# IdKey is necessary for current SMP structure
KEY_DIST = [i for i in range(16) if (i & SMP_ID_KEY_DISTRIBUTION_FLAG)]
@pytest.mark.asyncio
@pytest.mark.parametrize('io_cap, sc, mitm, key_dist',