Compare commits

...

9 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
20dedbd923 maintain the att mtu only at the connection level 2022-10-04 20:04:43 -07:00
Gilles Boccon-Gibod
0edd6b731f Merge pull request #37 from google/gbg/gatt-notify-with-value
add support for notifying with a transient value
2022-10-04 10:33:04 -07:00
Gilles Boccon-Gibod
a2f18cffc9 Merge pull request #38 from google/gbg/usb-interface-discovery
add support for dynamic discovery of USB endpoint addresses
2022-09-21 11:40:13 -07:00
Gilles Boccon-Gibod
db5e52f1df add support for alternate settings 2022-09-20 22:25:40 -07:00
Gilles Boccon-Gibod
d7da5a9379 add support for dynamic discovery of USB endpoints 2022-09-20 16:39:12 -07:00
Gilles Boccon-Gibod
80569bc9f3 add support for notifying with a transient value 2022-09-06 12:42:35 -07:00
Gilles Boccon-Gibod
daa05b8996 Merge pull request #36 from google/gbg/pairing-with-no-distribution
gbg/pairing with no distribution
2022-09-02 10:17:31 -07:00
Gilles Boccon-Gibod
624e860762 support empty distributions in both directions 2022-08-30 18:50:48 -07:00
Gilles Boccon-Gibod
159cbf7774 support pairing with no key distribution 2022-08-30 18:28:24 -07:00
10 changed files with 494 additions and 193 deletions

View File

@@ -700,16 +700,26 @@ class Attribute(EventEmitter):
else:
self.value = value
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
def read_value(self, connection):
if read := getattr(self.value, 'read', None):
try:
return read(connection)
value = read(connection)
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
else:
return self.value
value = self.value
return self.encode_value(value)
def write_value(self, connection, value_bytes):
value = self.decode_value(value_bytes)
def write_value(self, connection, value):
if write := getattr(self.value, 'write', None):
try:
write(connection, value)
@@ -721,7 +731,11 @@ class Attribute(EventEmitter):
self.emit('write', connection, value)
def __repr__(self):
if len(self.value) > 0:
if type(self.value) is bytes:
value_str = self.value.hex()
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}'
else:
value_string = ''

View File

@@ -1210,17 +1210,17 @@ class Device(CompositeEventEmitter):
def add_services(self, services):
self.gatt_server.add_services(services)
async def notify_subscriber(self, connection, attribute, force=False):
await self.gatt_server.notify_subscriber(connection, attribute, force)
async def notify_subscriber(self, connection, attribute, value=None, force=False):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(self, attribute, force=False):
await self.gatt_server.notify_subscribers(attribute, force)
async def notify_subscribers(self, attribute, value=None, force=False):
await self.gatt_server.notify_subscribers(attribute, value, force)
async def indicate_subscriber(self, connection, attribute, force=False):
await self.gatt_server.indicate_subscriber(connection, attribute, force)
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(self, attribute):
await self.gatt_server.indicate_subscribers(attribute)
async def indicate_subscribers(self, attribute, value=None, force=False):
await self.gatt_server.indicate_subscribers(attribute, value, force)
@host_event_handler
def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters):

View File

@@ -303,6 +303,7 @@ class CharacteristicAdapter:
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if (
asyncio.iscoroutinefunction(characteristic.read_value) and
@@ -317,11 +318,21 @@ class CharacteristicAdapter:
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
if name in {
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe'
}:
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
@@ -345,9 +356,26 @@ class CharacteristicAdapter:
return value
def wrapped_subscribe(self, subscriber=None):
return self.wrapped_characteristic.subscribe(
None if subscriber is None else lambda value: subscriber(self.decode_value(value))
)
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self):
wrapped = str(self.wrapped_characteristic)

View File

@@ -58,10 +58,16 @@ class AttributeProxy(EventEmitter):
self.type = attribute_type
async def read_value(self, no_long_read=False):
return await self.client.read_value(self.handle, no_long_read)
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
async def write_value(self, value, with_response=False):
return await self.client.write_value(self.handle, value, with_response)
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
@@ -98,6 +104,7 @@ class CharacteristicProxy(AttributeProxy):
self.properties = properties
self.descriptors = []
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
@@ -108,9 +115,25 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return await self.client.subscribe(self, subscriber)
async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
@@ -140,7 +163,6 @@ class ProfileServiceProxy:
class Client:
def __init__(self, connection):
self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -194,7 +216,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return
return self.connection.att_mtu
# Send the request
self.mtu_exchange_done = True
@@ -207,8 +229,10 @@ class Client:
response
)
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
return self.mtu
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
@@ -570,12 +594,18 @@ class Client:
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
subscriber_set.remove(characteristic.handle)
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
await self.write_value(cccd, b'\x00\x00', with_response=True)
if not self.notification_subscribers and not self.indication_subscribers:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False):
'''
@@ -600,7 +630,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1:
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -622,7 +652,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.mtu - 1:
if len(part) < self.connection.att_mtu - 1:
break
offset += len(part)

View File

@@ -40,6 +40,12 @@ from .gatt import *
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -49,9 +55,8 @@ class Server(EventEmitter):
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
@@ -169,7 +174,7 @@ class Server(EventEmitter):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, force=False):
async def notify_subscriber(self, connection, attribute, value=None, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -184,13 +189,12 @@ class Server(EventEmitter):
logger.debug(f'not notifying, cccd={cccd.hex()}')
return
# Get the value
value = attribute.read_value(connection)
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
# Notify
notification = ATT_Handle_Value_Notification(
@@ -198,27 +202,9 @@ class Server(EventEmitter):
attribute_value = value
)
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, notification.to_bytes())
self.send_gatt_pdu(connection.handle, bytes(notification))
async def notify_subscribers(self, attribute, force=False):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Notify for each connection
if connections:
await asyncio.wait([
self.notify_subscriber(connection, attribute, force)
for connection in connections
])
async def indicate_subscriber(self, connection, attribute, force=False):
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -233,13 +219,12 @@ class Server(EventEmitter):
logger.debug(f'not indicating, cccd={cccd.hex()}')
return
# Get the value
value = attribute.read_value(connection)
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
# Indicate
indication = ATT_Handle_Value_Indication(
@@ -264,27 +249,32 @@ class Server(EventEmitter):
finally:
self.pending_confirmations[connection.handle] = None
async def indicate_subscribers(self, attribute):
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if subscribers.get(attribute.handle)
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Indicate for each connection
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait([
self.indicate_subscriber(connection, attribute)
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
])
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -325,9 +315,6 @@ class Server(EventEmitter):
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
#######################################################
# ATT handlers
#######################################################
@@ -347,12 +334,16 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(self, connection, request):
'''
@@ -369,7 +360,7 @@ class Server(EventEmitter):
return
# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
uuid_size = 0
for attribute in (
@@ -420,7 +411,7 @@ class Server(EventEmitter):
'''
# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -468,8 +459,7 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -482,7 +472,7 @@ class Server(EventEmitter):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -522,7 +512,7 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value))
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
@@ -541,7 +531,6 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
@@ -549,14 +538,14 @@ class Server(EventEmitter):
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
)
elif len(value) <= mtu - 1:
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
)
else:
part_size = min(mtu - 1, len(value) - request.value_offset)
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
)
@@ -585,8 +574,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
return
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -597,7 +585,7 @@ class Server(EventEmitter):
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251)
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]

View File

@@ -18,6 +18,8 @@
import logging
from colors import color
from bumble.smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
@@ -73,6 +75,9 @@ class PacketTracer:
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)

View File

@@ -155,6 +155,7 @@ SMP_CT2_AUTHREQ = 0b00100000
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -879,7 +880,7 @@ class Session:
)
)
)
async def derive_ltk(self):
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None
@@ -914,7 +915,7 @@ class Session:
csrk = bytes(16) # FIXME: testing
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
ilk = crypto.h7(
@@ -946,7 +947,7 @@ class Session:
csrk = bytes(16) # FIXME: testing
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
ilk = crypto.h7(
@@ -980,12 +981,7 @@ class Session:
self.peer_expected_distributions.remove(command_class)
logger.debug(f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}')
if not self.peer_expected_distributions:
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
# Nothing left to expect, we're done
asyncio.create_task(self.on_pairing())
self.on_peer_key_distribution_complete()
else:
logger.warn(color(f'!!! unexpected key distribution command: {command_class.__name__}', 'red'))
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
@@ -1006,12 +1002,23 @@ class Session:
self.connection.remove_listener('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh)
self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self):
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
asyncio.create_task(self.on_pairing())
def on_connection_encryption_change(self):
if self.connection.is_encrypted:
if self.is_responder:
# The responder distributes its keys first, the initiator later
self.distribute_keys()
# If we're not expecting key distributions from the peer, we're done
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self):
# Do as if the connection had just been encrypted
self.on_connection_encryption_change()

View File

@@ -56,18 +56,19 @@ async def open_usb_transport(spec):
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_ACL_OUT = 0x02
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device):
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
@@ -116,7 +117,7 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
USB_ENDPOINT_ACL_OUT,
self.acl_out,
packet[1:],
callback=self.on_packet_sent
)
@@ -152,10 +153,12 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device):
def __init__(self, context, device, acl_in, events_in):
super().__init__()
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
@@ -172,7 +175,7 @@ async def open_usb_transport(spec):
# Set up transfer objects for input
self.events_in_transfer = device.getTransfer()
self.events_in_transfer.setInterrupt(
USB_ENDPOINT_EVENTS_IN,
self.events_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET
@@ -181,7 +184,7 @@ async def open_usb_transport(spec):
self.acl_in_transfer = device.getTransfer()
self.acl_in_transfer.setBulk(
USB_ENDPOINT_ACL_IN,
self.acl_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET
@@ -248,7 +251,7 @@ async def open_usb_transport(spec):
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, source, sink):
def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
@@ -257,6 +260,10 @@ async def open_usb_transport(spec):
# Get exclusive access
device.claimInterface(interface)
# Set the alternate setting if not the default
if setting != 0:
device.setInterfaceAltSetting(interface, setting)
# The source and sink can now start
source.start()
sink.start()
@@ -313,10 +320,63 @@ async def open_usb_transport(spec):
raise ValueError('device not found')
logger.debug(f'USB Device: {found}')
device = found.open()
# Use the first interface
interface = 0
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
setting.getClass() != USB_DEVICE_CLASS_WIRELESS_CONTROLLER or
setting.getSubClass() != USB_DEVICE_SUBCLASS_RF_CONTROLLER or
setting.getProtocol() != USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
):
continue
events_in = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None:
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in
)
else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}')
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
f'interface={interface}, '
f'setting={setting}, '
f'acl_in=0x{acl_in:02X}, '
f'acl_out=0x{acl_out:02X}, '
f'events_in=0x{events_in:02X}, '
)
device = found.open()
# Detach the kernel driver if supported and needed
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
@@ -329,21 +389,21 @@ async def open_usb_transport(spec):
# Set the configuration if needed
try:
configuration = device.getConfiguration()
logger.debug(f'current configuration = {configuration}')
current_configuration = device.getConfiguration()
logger.debug(f'current configuration = {current_configuration}')
except usb1.USBError:
configuration = 0
current_configuration = 0
if configuration != 1:
if current_configuration != configuration:
try:
logger.debug('setting configuration 1')
device.setConfiguration(1)
logger.debug(f'setting configuration {configuration}')
device.setConfiguration(configuration)
except usb1.USBError:
logger.warning('failed to set configuration 1')
logger.warning('failed to set configuration')
source = UsbPacketSource(context, device)
sink = UsbPacketSink(device)
return UsbTransport(context, device, interface, source, sink)
source = UsbPacketSource(context, device, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
context.close()

View File

@@ -22,6 +22,7 @@ import struct
import pytest
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -53,29 +54,29 @@ def basic_check(x):
parsed = ATT_PDU.from_bytes(pdu)
x_str = str(x)
parsed_str = str(parsed)
assert(x_str == parsed_str)
assert x_str == parsed_str
# -----------------------------------------------------------------------------
def test_UUID():
u = UUID.from_16_bits(0x7788)
assert(str(u) == 'UUID-16:7788')
assert str(u) == 'UUID-16:7788'
u = UUID.from_32_bits(0x11223344)
assert(str(u) == 'UUID-32:11223344')
assert str(u) == 'UUID-32:11223344'
u = UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
assert(str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
v = UUID(str(u))
assert(str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
w = UUID.from_bytes(v.to_bytes())
assert(str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6')
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
u1 = UUID.from_16_bits(0x1234)
b1 = u1.to_bytes(force_128 = True)
u2 = UUID.from_bytes(b1)
assert(u1 == u2)
assert u1 == u2
u3 = UUID.from_16_bits(0x180a)
assert(str(u3) == 'UUID-16:180A (Device Information)')
assert str(u3) == 'UUID-16:180A (Device Information)'
# -----------------------------------------------------------------------------
@@ -98,6 +99,122 @@ def test_ATT_Read_By_Group_Type_Request():
basic_check(pdu)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_characteristic_encoding():
class Foo(Characteristic):
def encode_value(self, value):
return bytes([value])
def decode_value(self, value_bytes):
return value_bytes[0]
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123)
x = c.read_value(None)
assert x == bytes([123])
c.write_value(None, bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
def __init__(self, characteristic):
super().__init__(
characteristic.client,
characteristic.handle,
characteristic.end_group_handle,
characteristic.uuid,
characteristic.properties
)
def encode_value(self, value):
return bytes([value])
def decode_value(self, value_bytes):
return value_bytes[0]
[client, server] = LinkedDevices().devices[:2]
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123])
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic]
)
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic.uuid)
assert len(c) == 1
c = c[0]
cp = FooProxy(c)
v = await cp.read_value()
assert v == 123
await cp.write_value(124)
await async_barrier()
assert characteristic.value == bytes([124])
last_change = None
def on_change(value):
nonlocal last_change
last_change = value
await c.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value
last_change = None
await server.notify_subscribers(characteristic, value=bytes([125]))
await async_barrier()
assert last_change == bytes([125])
last_change = None
await c.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
await cp.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value[0]
last_change = None
await server.notify_subscribers(characteristic, value=bytes([126]))
await async_barrier()
assert last_change == 126
last_change = None
await cp.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
cd = DelegatedCharacteristicAdapter(c, decode=lambda x: x[0])
await cd.subscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change == characteristic.value[0]
last_change = None
await cd.unsubscribe(on_change)
await server.notify_subscribers(characteristic)
await async_barrier()
assert last_change is None
# -----------------------------------------------------------------------------
def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
@@ -106,21 +223,21 @@ def test_CharacteristicAdapter():
a = CharacteristicAdapter(c)
value = a.read_value(None)
assert(value == v)
assert value == v
v = bytes([3, 4, 5])
a.write_value(None, v)
assert(c.value == v)
assert c.value == v
# Simple delegated adapter
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
value = a.read_value(None)
assert(value == bytes(reversed(v)))
assert value == bytes(reversed(v))
v = bytes([3, 4, 5])
a.write_value(None, v)
assert(a.value == bytes(reversed(v)))
assert a.value == bytes(reversed(v))
# Packed adapter with single element format
v = 1234
@@ -129,10 +246,10 @@ def test_CharacteristicAdapter():
a = PackedCharacteristicAdapter(c, '>H')
value = a.read_value(None)
assert(value == pv)
assert value == pv
c.value = None
a.write_value(None, pv)
assert(a.value == v)
assert a.value == v
# Packed adapter with multi-element format
v1 = 1234
@@ -142,10 +259,10 @@ def test_CharacteristicAdapter():
a = PackedCharacteristicAdapter(c, '>HH')
value = a.read_value(None)
assert(value == pv)
assert value == pv
c.value = None
a.write_value(None, pv)
assert(a.value == (v1, v2))
assert a.value == (v1, v2)
# Mapped adapter
v1 = 1234
@@ -156,10 +273,10 @@ def test_CharacteristicAdapter():
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = a.read_value(None)
assert(value == pv)
assert value == pv
c.value = None
a.write_value(None, pv)
assert(a.value == mapped)
assert a.value == mapped
# UTF-8 adapter
v = 'Hello π'
@@ -168,10 +285,10 @@ def test_CharacteristicAdapter():
a = UTF8CharacteristicAdapter(c)
value = a.read_value(None)
assert(value == ev)
assert value == ev
c.value = None
a.write_value(None, ev)
assert(a.value == v)
assert a.value == v
# -----------------------------------------------------------------------------
@@ -179,24 +296,25 @@ def test_CharacteristicValue():
b = bytes([1, 2, 3])
c = CharacteristicValue(read=lambda _: b)
x = c.read(None)
assert(x == b)
assert x == b
result = []
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
z = object()
c.write(z, b)
assert(result == [(z, b)])
assert result == [(z, b)]
# -----------------------------------------------------------------------------
class TwoDevices:
class LinkedDevices:
def __init__(self):
self.connections = [None, None]
self.connections = [None, None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
Controller('C2', link = self.link),
Controller('C3', link = self.link)
]
self.devices = [
Device(
@@ -204,12 +322,16 @@ class TwoDevices:
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
address = 'F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
)
]
self.paired = [None, None]
self.paired = [None, None, None]
# -----------------------------------------------------------------------------
@@ -222,7 +344,7 @@ async def async_barrier():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -265,41 +387,41 @@ async def test_read_write():
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert(len(c) == 1)
assert len(c) == 1
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert(len(c) == 1)
assert len(c) == 1
c2 = c[0]
v1 = await peer.read_value(c1)
assert(v1 == b'')
assert v1 == b''
b = bytes([1, 2, 3])
await peer.write_value(c1, b)
await async_barrier()
assert(characteristic1.value == b)
assert characteristic1.value == b
v1 = await peer.read_value(c1)
assert(v1 == b)
assert(type(characteristic1._last_value) is tuple)
assert(len(characteristic1._last_value) == 2)
assert(str(characteristic1._last_value[0].peer_address) == str(client.random_address))
assert(characteristic1._last_value[1] == b)
assert v1 == b
assert type(characteristic1._last_value is tuple)
assert len(characteristic1._last_value) == 2
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address)
assert characteristic1._last_value[1] == b
bb = bytes([3, 4, 5, 6])
characteristic1.value = bb
v1 = await peer.read_value(c1)
assert(v1 == bb)
assert v1 == bb
await peer.write_value(c2, b)
await async_barrier()
assert(type(characteristic2._last_value) is tuple)
assert(len(characteristic2._last_value) == 2)
assert(str(characteristic2._last_value[0].peer_address) == str(client.random_address))
assert(characteristic2._last_value[1] == b)
assert type(characteristic2._last_value is tuple)
assert len(characteristic2._last_value) == 2
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address)
assert characteristic2._last_value[1] == b
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write2():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic(
@@ -324,32 +446,32 @@ async def test_read_write2():
await peer.discover_services()
c = peer.get_services_by_uuid(service1.uuid)
assert(len(c) == 1)
assert len(c) == 1
s = c[0]
await s.discover_characteristics()
c = s.get_characteristics_by_uuid(characteristic1.uuid)
assert(len(c) == 1)
assert len(c) == 1
c1 = c[0]
v1 = await c1.read_value()
assert(v1 == v)
assert v1 == v
a1 = PackedCharacteristicAdapter(c1, '>I')
v1 = await a1.read_value()
assert(v1 == struct.unpack('>I', v)[0])
assert v1 == struct.unpack('>I', v)[0]
b = bytes([0x55, 0x66, 0x77, 0x88])
await a1.write_value(struct.unpack('>I', b)[0])
await async_barrier()
assert(characteristic1.value == b)
assert characteristic1.value == b
v1 = await a1.read_value()
assert(v1 == struct.unpack('>I', b)[0])
assert v1 == struct.unpack('>I', b)[0]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_notify():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -410,13 +532,13 @@ async def test_subscribe_notify():
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert(len(c) == 1)
assert len(c) == 1
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert(len(c) == 1)
assert len(c) == 1
c2 = c[0]
c = peer.get_characteristics_by_uuid(characteristic3.uuid)
assert(len(c) == 1)
assert len(c) == 1
c3 = c[0]
c1._called = False
@@ -429,23 +551,32 @@ async def test_subscribe_notify():
c1.on('update', on_c1_update)
await peer.subscribe(c1)
await async_barrier()
assert(server._last_subscription[1] == characteristic1)
assert(server._last_subscription[2])
assert(not server._last_subscription[3])
assert(characteristic1._last_subscription[1])
assert(not characteristic1._last_subscription[2])
assert server._last_subscription[1] == characteristic1
assert server._last_subscription[2]
assert not server._last_subscription[3]
assert characteristic1._last_subscription[1]
assert not characteristic1._last_subscription[2]
await server.indicate_subscribers(characteristic1)
await async_barrier()
assert(not c1._called)
assert not c1._called
await server.notify_subscribers(characteristic1)
await async_barrier()
assert(c1._called)
assert(c1._last_update == characteristic1.value)
assert c1._called
assert c1._last_update == characteristic1.value
c1._called = False
c1._last_update = None
c1_value = characteristic1.value
await server.notify_subscribers(characteristic1, bytes([0, 1, 2]))
await async_barrier()
assert c1._called
assert c1._last_update == bytes([0, 1, 2])
assert characteristic1.value == c1_value
c1._called = False
await peer.unsubscribe(c1)
await server.notify_subscribers(characteristic1)
assert(not c1._called)
assert not c1._called
c2._called = False
c2._last_update = None
@@ -458,17 +589,17 @@ async def test_subscribe_notify():
await async_barrier()
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert(not c2._called)
assert not c2._called
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert(c2._called)
assert(c2._last_update == characteristic2.value)
assert c2._called
assert c2._last_update == characteristic2.value
c2._called = False
await peer.unsubscribe(c2, on_c2_update)
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await async_barrier()
assert(not c2._called)
assert not c2._called
def on_c3_update(value):
c3._called = True
@@ -483,17 +614,17 @@ async def test_subscribe_notify():
await async_barrier()
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert(c3._called)
assert(c3._last_update == characteristic3.value)
assert(c3._called_2)
assert(c3._last_update_2 == characteristic3.value)
assert c3._called
assert c3._last_update == characteristic3.value
assert c3._called_2
assert c3._last_update_2 == characteristic3.value
characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert(c3._called)
assert(c3._last_update == characteristic3.value)
assert(c3._called_2)
assert(c3._last_update_2 == characteristic3.value)
assert c3._called
assert c3._last_update == characteristic3.value
assert c3._called_2
assert c3._last_update_2 == characteristic3.value
c3._called = False
c3._called_2 = False
@@ -501,8 +632,44 @@ async def test_subscribe_notify():
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier()
assert(not c3._called)
assert(not c3._called_2)
assert not c3._called
assert not c3._called_2
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]
d3.gatt_server.max_mtu = 100
d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)
await d1.power_on()
await d2.power_on()
await d3.power_on()
d1_connection = await d1.connect(d3.random_address)
assert len(d3_connections) == 1
assert d3_connections[0] is not None
d2_connection = await d2.connect(d3.random_address)
assert len(d3_connections) == 2
assert d3_connections[1] is not None
d1_peer = Peer(d1_connection)
d2_peer = Peer(d2_connection)
d1_client_mtu = await d1_peer.request_mtu(220)
assert d1_client_mtu == 100
assert d1_connection.att_mtu == 100
d2_client_mtu = await d2_peer.request_mtu(50)
assert d2_client_mtu == 50
assert d2_connection.att_mtu == 50
# -----------------------------------------------------------------------------
@@ -510,6 +677,9 @@ async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_characteristic_encoding()
await test_mtu_exchange()
# -----------------------------------------------------------------------------
if __name__ == '__main__':

View File

@@ -246,8 +246,7 @@ IO_CAP = [
SC = [False, True]
MITM = [False, True]
# Key distribution is a 4-bit bitmask
# IdKey is necessary for current SMP structure
KEY_DIST = [i for i in range(16) if (i & SMP_ID_KEY_DISTRIBUTION_FLAG)]
KEY_DIST = range(16)
@pytest.mark.asyncio
@pytest.mark.parametrize('io_cap, sc, mitm, key_dist',