Merge pull request #40 from google/gbg/gatt-mtu

maintain the att mtu only at the connection level
This commit is contained in:
Gilles Boccon-Gibod
2022-10-05 13:53:47 -07:00
committed by GitHub
3 changed files with 87 additions and 45 deletions
+7 -6
View File
@@ -163,7 +163,6 @@ class ProfileServiceProxy:
class Client:
def __init__(self, connection):
self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -217,7 +216,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return
return self.connection.att_mtu
# Send the request
self.mtu_exchange_done = True
@@ -230,8 +229,10 @@ class Client:
response
)
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
return self.mtu
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
@@ -629,7 +630,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1:
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -651,7 +652,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.mtu - 1:
if len(part) < self.connection.att_mtu - 1:
break
offset += len(part)
+29 -30
View File
@@ -40,6 +40,12 @@ from .gatt import *
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -49,9 +55,8 @@ class Server(EventEmitter):
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
@@ -188,9 +193,8 @@ class Server(EventEmitter):
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
# Notify
notification = ATT_Handle_Value_Notification(
@@ -219,9 +223,8 @@ class Server(EventEmitter):
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
# Indicate
indication = ATT_Handle_Value_Indication(
@@ -272,8 +275,6 @@ class Server(EventEmitter):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -314,9 +315,6 @@ class Server(EventEmitter):
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
#######################################################
# ATT handlers
#######################################################
@@ -336,12 +334,16 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(self, connection, request):
'''
@@ -358,7 +360,7 @@ class Server(EventEmitter):
return
# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
uuid_size = 0
for attribute in (
@@ -409,7 +411,7 @@ class Server(EventEmitter):
'''
# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -457,8 +459,7 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -471,7 +472,7 @@ class Server(EventEmitter):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -511,7 +512,7 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value))
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
@@ -530,7 +531,6 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
@@ -538,14 +538,14 @@ class Server(EventEmitter):
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
)
elif len(value) <= mtu - 1:
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
)
else:
part_size = min(mtu - 1, len(value) - request.value_offset)
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
)
@@ -574,8 +574,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
return
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
@@ -586,7 +585,7 @@ class Server(EventEmitter):
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251)
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
+51 -9
View File
@@ -131,7 +131,7 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes):
return value_bytes[0]
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -306,14 +306,15 @@ def test_CharacteristicValue():
# -----------------------------------------------------------------------------
class TwoDevices:
class LinkedDevices:
def __init__(self):
self.connections = [None, None]
self.connections = [None, None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
Controller('C2', link = self.link),
Controller('C3', link = self.link)
]
self.devices = [
Device(
@@ -321,12 +322,16 @@ class TwoDevices:
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
address = 'F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
)
]
self.paired = [None, None]
self.paired = [None, None, None]
# -----------------------------------------------------------------------------
@@ -339,7 +344,7 @@ async def async_barrier():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -416,7 +421,7 @@ async def test_read_write():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write2():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic(
@@ -466,7 +471,7 @@ async def test_read_write2():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_notify():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -631,12 +636,49 @@ async def test_subscribe_notify():
assert not c3._called_2
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]
d3.gatt_server.max_mtu = 100
d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)
await d1.power_on()
await d2.power_on()
await d3.power_on()
d1_connection = await d1.connect(d3.random_address)
assert len(d3_connections) == 1
assert d3_connections[0] is not None
d2_connection = await d2.connect(d3.random_address)
assert len(d3_connections) == 2
assert d3_connections[1] is not None
d1_peer = Peer(d1_connection)
d2_peer = Peer(d2_connection)
d1_client_mtu = await d1_peer.request_mtu(220)
assert d1_client_mtu == 100
assert d1_connection.att_mtu == 100
d2_client_mtu = await d2_peer.request_mtu(50)
assert d2_client_mtu == 50
assert d2_connection.att_mtu == 50
# -----------------------------------------------------------------------------
async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_characteristic_encoding()
await test_mtu_exchange()
# -----------------------------------------------------------------------------