Merge pull request #40 from google/gbg/gatt-mtu

maintain the att mtu only at the connection level
This commit is contained in:
Gilles Boccon-Gibod
2022-10-05 13:53:47 -07:00
committed by GitHub
3 changed files with 87 additions and 45 deletions
+7 -6
View File
@@ -163,7 +163,6 @@ class ProfileServiceProxy:
class Client: class Client:
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None self.pending_request = None
@@ -217,7 +216,7 @@ class Client:
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:
return return self.connection.att_mtu
# Send the request # Send the request
self.mtu_exchange_done = True self.mtu_exchange_done = True
@@ -230,8 +229,10 @@ class Client:
response response
) )
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu) # Compute the final MTU
return self.mtu self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
@@ -629,7 +630,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller # If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that # specifically asked not to do that
attribute_value = response.attribute_value attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1: if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value') logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value) offset = len(attribute_value)
while True: while True:
@@ -651,7 +652,7 @@ class Client:
part = response.part_attribute_value part = response.part_attribute_value
attribute_value += part attribute_value += part
if len(part) < self.mtu - 1: if len(part) < self.connection.att_mtu - 1:
break break
offset += len(part) offset += len(part)
+27 -28
View File
@@ -40,6 +40,12 @@ from .gatt import *
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GATT Server # GATT Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -49,9 +55,8 @@ class Server(EventEmitter):
self.device = device self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None) self.pending_confirmations = defaultdict(lambda: None)
@@ -188,9 +193,8 @@ class Server(EventEmitter):
value = attribute.read_value(connection) if value is None else attribute.encode_value(value) value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed # Truncate if needed
mtu = self.get_mtu(connection) if len(value) > connection.att_mtu - 3:
if len(value) > mtu - 3: value = value[:connection.att_mtu - 3]
value = value[:mtu - 3]
# Notify # Notify
notification = ATT_Handle_Value_Notification( notification = ATT_Handle_Value_Notification(
@@ -219,9 +223,8 @@ class Server(EventEmitter):
value = attribute.read_value(connection) if value is None else attribute.encode_value(value) value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed # Truncate if needed
mtu = self.get_mtu(connection) if len(value) > connection.att_mtu - 3:
if len(value) > mtu - 3: value = value[:connection.att_mtu - 3]
value = value[:mtu - 3]
# Indicate # Indicate
indication = ATT_Handle_Value_Indication( indication = ATT_Handle_Value_Indication(
@@ -272,8 +275,6 @@ class Server(EventEmitter):
return await self.notify_or_indicate_subscribers(True, attribute, value, force) return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection): def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers: if connection.handle in self.subscribers:
del self.subscribers[connection.handle] del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores: if connection.handle in self.indication_semaphores:
@@ -314,9 +315,6 @@ class Server(EventEmitter):
# Just ignore # Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}') logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
####################################################### #######################################################
# ATT handlers # ATT handlers
####################################################### #######################################################
@@ -336,12 +334,16 @@ class Server(EventEmitter):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
''' '''
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu)) self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu)) # Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device # Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu) self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(self, connection, request): def on_att_find_information_request(self, connection, request):
''' '''
@@ -358,7 +360,7 @@ class Server(EventEmitter):
return return
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
uuid_size = 0 uuid_size = 0
for attribute in ( for attribute in (
@@ -409,7 +411,7 @@ class Server(EventEmitter):
''' '''
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute for attribute in self.attributes if
@@ -457,8 +459,7 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
mtu = self.get_mtu(connection) pdu_space_available = connection.att_mtu - 2
pdu_space_available = mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute for attribute in self.attributes if
@@ -471,7 +472,7 @@ class Server(EventEmitter):
# Check the attribute value size # Check the attribute value size
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253) max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
@@ -511,7 +512,7 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions # TODO: check permissions
value = attribute.read_value(connection) value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value)) value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response( response = ATT_Read_Response(
attribute_value = value[:value_size] attribute_value = value[:value_size]
) )
@@ -530,7 +531,6 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions # TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection) value = attribute.read_value(connection)
if request.value_offset > len(value): if request.value_offset > len(value):
response = ATT_Error_Response( response = ATT_Error_Response(
@@ -538,14 +538,14 @@ class Server(EventEmitter):
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR error_code = ATT_INVALID_OFFSET_ERROR
) )
elif len(value) <= mtu - 1: elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
) )
else: else:
part_size = min(mtu - 1, len(value) - request.value_offset) part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response( response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size] part_attribute_value = value[request.value_offset:request.value_offset + part_size]
) )
@@ -574,8 +574,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
return return
mtu = self.get_mtu(connection) pdu_space_available = connection.att_mtu - 2
pdu_space_available = mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute for attribute in self.attributes if
@@ -586,7 +585,7 @@ class Server(EventEmitter):
): ):
# Check the attribute value size # Check the attribute value size
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251) max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
+51 -9
View File
@@ -131,7 +131,7 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes): def decode_value(self, value_bytes):
return value_bytes[0] return value_bytes[0]
[client, server] = TwoDevices().devices [client, server] = LinkedDevices().devices[:2]
characteristic = Characteristic( characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -306,14 +306,15 @@ def test_CharacteristicValue():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TwoDevices: class LinkedDevices:
def __init__(self): def __init__(self):
self.connections = [None, None] self.connections = [None, None, None]
self.link = LocalLink() self.link = LocalLink()
self.controllers = [ self.controllers = [
Controller('C1', link = self.link), Controller('C1', link = self.link),
Controller('C2', link = self.link) Controller('C2', link = self.link),
Controller('C3', link = self.link)
] ]
self.devices = [ self.devices = [
Device( Device(
@@ -321,12 +322,16 @@ class TwoDevices:
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
), ),
Device( Device(
address = 'F5:F4:F3:F2:F1:F0', address = 'F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
) )
] ]
self.paired = [None, None] self.paired = [None, None, None]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -339,7 +344,7 @@ async def async_barrier():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_read_write(): async def test_read_write():
[client, server] = TwoDevices().devices [client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -416,7 +421,7 @@ async def test_read_write():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_read_write2(): async def test_read_write2():
[client, server] = TwoDevices().devices [client, server] = LinkedDevices().devices[:2]
v = bytes([0x11, 0x22, 0x33, 0x44]) v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic( characteristic1 = Characteristic(
@@ -466,7 +471,7 @@ async def test_read_write2():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscribe_notify(): async def test_subscribe_notify():
[client, server] = TwoDevices().devices [client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -631,12 +636,49 @@ async def test_subscribe_notify():
assert not c3._called_2 assert not c3._called_2
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]
d3.gatt_server.max_mtu = 100
d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)
await d1.power_on()
await d2.power_on()
await d3.power_on()
d1_connection = await d1.connect(d3.random_address)
assert len(d3_connections) == 1
assert d3_connections[0] is not None
d2_connection = await d2.connect(d3.random_address)
assert len(d3_connections) == 2
assert d3_connections[1] is not None
d1_peer = Peer(d1_connection)
d2_peer = Peer(d2_connection)
d1_client_mtu = await d1_peer.request_mtu(220)
assert d1_client_mtu == 100
assert d1_connection.att_mtu == 100
d2_client_mtu = await d2_peer.request_mtu(50)
assert d2_client_mtu == 50
assert d2_connection.att_mtu == 50
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
await test_read_write() await test_read_write()
await test_read_write2() await test_read_write2()
await test_subscribe_notify() await test_subscribe_notify()
await test_characteristic_encoding() await test_characteristic_encoding()
await test_mtu_exchange()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------