forked from auracaster/bumble_mirror
Merge pull request #40 from google/gbg/gatt-mtu
maintain the att mtu only at the connection level
This commit is contained in:
@@ -163,7 +163,6 @@ class ProfileServiceProxy:
|
|||||||
class Client:
|
class Client:
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.mtu = ATT_DEFAULT_MTU
|
|
||||||
self.mtu_exchange_done = False
|
self.mtu_exchange_done = False
|
||||||
self.request_semaphore = asyncio.Semaphore(1)
|
self.request_semaphore = asyncio.Semaphore(1)
|
||||||
self.pending_request = None
|
self.pending_request = None
|
||||||
@@ -217,7 +216,7 @@ class Client:
|
|||||||
|
|
||||||
# We can only send one request per connection
|
# We can only send one request per connection
|
||||||
if self.mtu_exchange_done:
|
if self.mtu_exchange_done:
|
||||||
return
|
return self.connection.att_mtu
|
||||||
|
|
||||||
# Send the request
|
# Send the request
|
||||||
self.mtu_exchange_done = True
|
self.mtu_exchange_done = True
|
||||||
@@ -230,8 +229,10 @@ class Client:
|
|||||||
response
|
response
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
|
# Compute the final MTU
|
||||||
return self.mtu
|
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
|
||||||
|
|
||||||
|
return self.connection.att_mtu
|
||||||
|
|
||||||
def get_services_by_uuid(self, uuid):
|
def get_services_by_uuid(self, uuid):
|
||||||
return [service for service in self.services if service.uuid == uuid]
|
return [service for service in self.services if service.uuid == uuid]
|
||||||
@@ -629,7 +630,7 @@ class Client:
|
|||||||
# If the value is the max size for the MTU, try to read more unless the caller
|
# If the value is the max size for the MTU, try to read more unless the caller
|
||||||
# specifically asked not to do that
|
# specifically asked not to do that
|
||||||
attribute_value = response.attribute_value
|
attribute_value = response.attribute_value
|
||||||
if not no_long_read and len(attribute_value) == self.mtu - 1:
|
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
|
||||||
logger.debug('using READ BLOB to get the rest of the value')
|
logger.debug('using READ BLOB to get the rest of the value')
|
||||||
offset = len(attribute_value)
|
offset = len(attribute_value)
|
||||||
while True:
|
while True:
|
||||||
@@ -651,7 +652,7 @@ class Client:
|
|||||||
part = response.part_attribute_value
|
part = response.part_attribute_value
|
||||||
attribute_value += part
|
attribute_value += part
|
||||||
|
|
||||||
if len(part) < self.mtu - 1:
|
if len(part) < self.connection.att_mtu - 1:
|
||||||
break
|
break
|
||||||
|
|
||||||
offset += len(part)
|
offset += len(part)
|
||||||
|
|||||||
+27
-28
@@ -40,6 +40,12 @@ from .gatt import *
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# GATT Server
|
# GATT Server
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -49,9 +55,8 @@ class Server(EventEmitter):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.attributes = [] # Attributes, ordered by increasing handle values
|
self.attributes = [] # Attributes, ordered by increasing handle values
|
||||||
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
||||||
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
|
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
|
||||||
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
|
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
|
||||||
self.mtus = {} # Map of ATT MTU values by connection handle
|
|
||||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||||
self.pending_confirmations = defaultdict(lambda: None)
|
self.pending_confirmations = defaultdict(lambda: None)
|
||||||
|
|
||||||
@@ -188,9 +193,8 @@ class Server(EventEmitter):
|
|||||||
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
mtu = self.get_mtu(connection)
|
if len(value) > connection.att_mtu - 3:
|
||||||
if len(value) > mtu - 3:
|
value = value[:connection.att_mtu - 3]
|
||||||
value = value[:mtu - 3]
|
|
||||||
|
|
||||||
# Notify
|
# Notify
|
||||||
notification = ATT_Handle_Value_Notification(
|
notification = ATT_Handle_Value_Notification(
|
||||||
@@ -219,9 +223,8 @@ class Server(EventEmitter):
|
|||||||
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
mtu = self.get_mtu(connection)
|
if len(value) > connection.att_mtu - 3:
|
||||||
if len(value) > mtu - 3:
|
value = value[:connection.att_mtu - 3]
|
||||||
value = value[:mtu - 3]
|
|
||||||
|
|
||||||
# Indicate
|
# Indicate
|
||||||
indication = ATT_Handle_Value_Indication(
|
indication = ATT_Handle_Value_Indication(
|
||||||
@@ -272,8 +275,6 @@ class Server(EventEmitter):
|
|||||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||||
|
|
||||||
def on_disconnection(self, connection):
|
def on_disconnection(self, connection):
|
||||||
if connection.handle in self.mtus:
|
|
||||||
del self.mtus[connection.handle]
|
|
||||||
if connection.handle in self.subscribers:
|
if connection.handle in self.subscribers:
|
||||||
del self.subscribers[connection.handle]
|
del self.subscribers[connection.handle]
|
||||||
if connection.handle in self.indication_semaphores:
|
if connection.handle in self.indication_semaphores:
|
||||||
@@ -314,9 +315,6 @@ class Server(EventEmitter):
|
|||||||
# Just ignore
|
# Just ignore
|
||||||
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
|
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
|
||||||
|
|
||||||
def get_mtu(self, connection):
|
|
||||||
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# ATT handlers
|
# ATT handlers
|
||||||
#######################################################
|
#######################################################
|
||||||
@@ -336,12 +334,16 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
||||||
'''
|
'''
|
||||||
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
|
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
|
||||||
self.mtus[connection.handle] = mtu
|
|
||||||
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
|
# Compute the final MTU
|
||||||
|
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
|
||||||
|
mtu = min(self.max_mtu, request.client_rx_mtu)
|
||||||
|
|
||||||
# Notify the device
|
# Notify the device
|
||||||
self.device.on_connection_att_mtu_update(connection.handle, mtu)
|
self.device.on_connection_att_mtu_update(connection.handle, mtu)
|
||||||
|
else:
|
||||||
|
logger.warning('invalid client_rx_mtu received, MTU not changed')
|
||||||
|
|
||||||
def on_att_find_information_request(self, connection, request):
|
def on_att_find_information_request(self, connection, request):
|
||||||
'''
|
'''
|
||||||
@@ -358,7 +360,7 @@ class Server(EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Build list of returned attributes
|
# Build list of returned attributes
|
||||||
pdu_space_available = self.get_mtu(connection) - 2
|
pdu_space_available = connection.att_mtu - 2
|
||||||
attributes = []
|
attributes = []
|
||||||
uuid_size = 0
|
uuid_size = 0
|
||||||
for attribute in (
|
for attribute in (
|
||||||
@@ -409,7 +411,7 @@ class Server(EventEmitter):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
# Build list of returned attributes
|
# Build list of returned attributes
|
||||||
pdu_space_available = self.get_mtu(connection) - 2
|
pdu_space_available = connection.att_mtu - 2
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -457,8 +459,7 @@ class Server(EventEmitter):
|
|||||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||||
'''
|
'''
|
||||||
|
|
||||||
mtu = self.get_mtu(connection)
|
pdu_space_available = connection.att_mtu - 2
|
||||||
pdu_space_available = mtu - 2
|
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -471,7 +472,7 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
# Check the attribute value size
|
# Check the attribute value size
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
max_attribute_size = min(mtu - 4, 253)
|
max_attribute_size = min(connection.att_mtu - 4, 253)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
attribute_value = attribute_value[:max_attribute_size]
|
attribute_value = attribute_value[:max_attribute_size]
|
||||||
@@ -511,7 +512,7 @@ class Server(EventEmitter):
|
|||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
# TODO: check permissions
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
value_size = min(self.get_mtu(connection) - 1, len(value))
|
value_size = min(connection.att_mtu - 1, len(value))
|
||||||
response = ATT_Read_Response(
|
response = ATT_Read_Response(
|
||||||
attribute_value = value[:value_size]
|
attribute_value = value[:value_size]
|
||||||
)
|
)
|
||||||
@@ -530,7 +531,6 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
if attribute := self.get_attribute(request.attribute_handle):
|
if attribute := self.get_attribute(request.attribute_handle):
|
||||||
# TODO: check permissions
|
# TODO: check permissions
|
||||||
mtu = self.get_mtu(connection)
|
|
||||||
value = attribute.read_value(connection)
|
value = attribute.read_value(connection)
|
||||||
if request.value_offset > len(value):
|
if request.value_offset > len(value):
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
@@ -538,14 +538,14 @@ class Server(EventEmitter):
|
|||||||
attribute_handle_in_error = request.attribute_handle,
|
attribute_handle_in_error = request.attribute_handle,
|
||||||
error_code = ATT_INVALID_OFFSET_ERROR
|
error_code = ATT_INVALID_OFFSET_ERROR
|
||||||
)
|
)
|
||||||
elif len(value) <= mtu - 1:
|
elif len(value) <= connection.att_mtu - 1:
|
||||||
response = ATT_Error_Response(
|
response = ATT_Error_Response(
|
||||||
request_opcode_in_error = request.op_code,
|
request_opcode_in_error = request.op_code,
|
||||||
attribute_handle_in_error = request.attribute_handle,
|
attribute_handle_in_error = request.attribute_handle,
|
||||||
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
|
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
part_size = min(mtu - 1, len(value) - request.value_offset)
|
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
|
||||||
response = ATT_Read_Blob_Response(
|
response = ATT_Read_Blob_Response(
|
||||||
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
|
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
|
||||||
)
|
)
|
||||||
@@ -574,8 +574,7 @@ class Server(EventEmitter):
|
|||||||
self.send_response(connection, response)
|
self.send_response(connection, response)
|
||||||
return
|
return
|
||||||
|
|
||||||
mtu = self.get_mtu(connection)
|
pdu_space_available = connection.att_mtu - 2
|
||||||
pdu_space_available = mtu - 2
|
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute in (
|
for attribute in (
|
||||||
attribute for attribute in self.attributes if
|
attribute for attribute in self.attributes if
|
||||||
@@ -586,7 +585,7 @@ class Server(EventEmitter):
|
|||||||
):
|
):
|
||||||
# Check the attribute value size
|
# Check the attribute value size
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
max_attribute_size = min(mtu - 6, 251)
|
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||||
if len(attribute_value) > max_attribute_size:
|
if len(attribute_value) > max_attribute_size:
|
||||||
# We need to truncate
|
# We need to truncate
|
||||||
attribute_value = attribute_value[:max_attribute_size]
|
attribute_value = attribute_value[:max_attribute_size]
|
||||||
|
|||||||
+51
-9
@@ -131,7 +131,7 @@ async def test_characteristic_encoding():
|
|||||||
def decode_value(self, value_bytes):
|
def decode_value(self, value_bytes):
|
||||||
return value_bytes[0]
|
return value_bytes[0]
|
||||||
|
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
characteristic = Characteristic(
|
characteristic = Characteristic(
|
||||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
@@ -306,14 +306,15 @@ def test_CharacteristicValue():
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class TwoDevices:
|
class LinkedDevices:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.connections = [None, None]
|
self.connections = [None, None, None]
|
||||||
|
|
||||||
self.link = LocalLink()
|
self.link = LocalLink()
|
||||||
self.controllers = [
|
self.controllers = [
|
||||||
Controller('C1', link = self.link),
|
Controller('C1', link = self.link),
|
||||||
Controller('C2', link = self.link)
|
Controller('C2', link = self.link),
|
||||||
|
Controller('C3', link = self.link)
|
||||||
]
|
]
|
||||||
self.devices = [
|
self.devices = [
|
||||||
Device(
|
Device(
|
||||||
@@ -321,12 +322,16 @@ class TwoDevices:
|
|||||||
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
|
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
|
||||||
),
|
),
|
||||||
Device(
|
Device(
|
||||||
address = 'F5:F4:F3:F2:F1:F0',
|
address = 'F1:F2:F3:F4:F5:F6',
|
||||||
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
|
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
|
||||||
|
),
|
||||||
|
Device(
|
||||||
|
address = 'F2:F3:F4:F5:F6:F7',
|
||||||
|
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.paired = [None, None]
|
self.paired = [None, None, None]
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -339,7 +344,7 @@ async def async_barrier():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_read_write():
|
async def test_read_write():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
@@ -416,7 +421,7 @@ async def test_read_write():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_read_write2():
|
async def test_read_write2():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
v = bytes([0x11, 0x22, 0x33, 0x44])
|
v = bytes([0x11, 0x22, 0x33, 0x44])
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
@@ -466,7 +471,7 @@ async def test_read_write2():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscribe_notify():
|
async def test_subscribe_notify():
|
||||||
[client, server] = TwoDevices().devices
|
[client, server] = LinkedDevices().devices[:2]
|
||||||
|
|
||||||
characteristic1 = Characteristic(
|
characteristic1 = Characteristic(
|
||||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||||
@@ -631,12 +636,49 @@ async def test_subscribe_notify():
|
|||||||
assert not c3._called_2
|
assert not c3._called_2
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mtu_exchange():
|
||||||
|
[d1, d2, d3] = LinkedDevices().devices[:3]
|
||||||
|
|
||||||
|
d3.gatt_server.max_mtu = 100
|
||||||
|
|
||||||
|
d3_connections = []
|
||||||
|
@d3.on('connection')
|
||||||
|
def on_d3_connection(connection):
|
||||||
|
d3_connections.append(connection)
|
||||||
|
|
||||||
|
await d1.power_on()
|
||||||
|
await d2.power_on()
|
||||||
|
await d3.power_on()
|
||||||
|
|
||||||
|
d1_connection = await d1.connect(d3.random_address)
|
||||||
|
assert len(d3_connections) == 1
|
||||||
|
assert d3_connections[0] is not None
|
||||||
|
|
||||||
|
d2_connection = await d2.connect(d3.random_address)
|
||||||
|
assert len(d3_connections) == 2
|
||||||
|
assert d3_connections[1] is not None
|
||||||
|
|
||||||
|
d1_peer = Peer(d1_connection)
|
||||||
|
d2_peer = Peer(d2_connection)
|
||||||
|
|
||||||
|
d1_client_mtu = await d1_peer.request_mtu(220)
|
||||||
|
assert d1_client_mtu == 100
|
||||||
|
assert d1_connection.att_mtu == 100
|
||||||
|
|
||||||
|
d2_client_mtu = await d2_peer.request_mtu(50)
|
||||||
|
assert d2_client_mtu == 50
|
||||||
|
assert d2_connection.att_mtu == 50
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def async_main():
|
async def async_main():
|
||||||
await test_read_write()
|
await test_read_write()
|
||||||
await test_read_write2()
|
await test_read_write2()
|
||||||
await test_subscribe_notify()
|
await test_subscribe_notify()
|
||||||
await test_characteristic_encoding()
|
await test_characteristic_encoding()
|
||||||
|
await test_mtu_exchange()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user