diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 4d0492d6..4cbf07e7 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -163,7 +163,6 @@ class ProfileServiceProxy: class Client: def __init__(self, connection): self.connection = connection - self.mtu = ATT_DEFAULT_MTU self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1) self.pending_request = None @@ -217,7 +216,7 @@ class Client: # We can only send one request per connection if self.mtu_exchange_done: - return + return self.connection.att_mtu # Send the request self.mtu_exchange_done = True @@ -230,8 +229,10 @@ class Client: response ) - self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu) - return self.mtu + # Compute the final MTU + self.connection.att_mtu = min(mtu, response.server_rx_mtu) + + return self.connection.att_mtu def get_services_by_uuid(self, uuid): return [service for service in self.services if service.uuid == uuid] @@ -629,7 +630,7 @@ class Client: # If the value is the max size for the MTU, try to read more unless the caller # specifically asked not to do that attribute_value = response.attribute_value - if not no_long_read and len(attribute_value) == self.mtu - 1: + if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1: logger.debug('using READ BLOB to get the rest of the value') offset = len(attribute_value) while True: @@ -651,7 +652,7 @@ class Client: part = response.part_attribute_value attribute_value += part - if len(part) < self.mtu - 1: + if len(part) < self.connection.att_mtu - 1: break offset += len(part) diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 9adf1037..656df6b7 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -40,6 +40,12 @@ from .gatt import * logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +GATT_SERVER_DEFAULT_MAX_MTU = 517 + + # ----------------------------------------------------------------------------- # GATT Server # ----------------------------------------------------------------------------- @@ -49,9 +55,8 @@ class Server(EventEmitter): self.device = device self.attributes = [] # Attributes, ordered by increasing handle values self.attributes_by_handle = {} # Map for fast attribute access by handle - self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate + self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate self.subscribers = {} # Map of subscriber states by connection handle and attribute handle - self.mtus = {} # Map of ATT MTU values by connection handle self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.pending_confirmations = defaultdict(lambda: None) @@ -188,9 +193,8 @@ class Server(EventEmitter): value = attribute.read_value(connection) if value is None else attribute.encode_value(value) # Truncate if needed - mtu = self.get_mtu(connection) - if len(value) > mtu - 3: - value = value[:mtu - 3] + if len(value) > connection.att_mtu - 3: + value = value[:connection.att_mtu - 3] # Notify notification = ATT_Handle_Value_Notification( @@ -219,9 +223,8 @@ class Server(EventEmitter): value = attribute.read_value(connection) if value is None else attribute.encode_value(value) # Truncate if needed - mtu = self.get_mtu(connection) - if len(value) > mtu - 3: - value = value[:mtu - 3] + if len(value) > connection.att_mtu - 3: + value = value[:connection.att_mtu - 3] # Indicate indication = ATT_Handle_Value_Indication( @@ -272,8 +275,6 @@ class Server(EventEmitter): return await self.notify_or_indicate_subscribers(True, attribute, value, force) def on_disconnection(self, connection): - if connection.handle in self.mtus: - del self.mtus[connection.handle] if connection.handle in self.subscribers: del self.subscribers[connection.handle] if connection.handle in self.indication_semaphores: @@ -314,9 +315,6 @@ class Server(EventEmitter): # Just ignore logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}') - def get_mtu(self, connection): - return self.mtus.get(connection.handle, ATT_DEFAULT_MTU) - ####################################################### # ATT handlers ####################################################### @@ -336,12 +334,16 @@ class Server(EventEmitter): ''' See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request ''' - mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu)) - self.mtus[connection.handle] = mtu - self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu)) + self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu)) - # Notify the device - self.device.on_connection_att_mtu_update(connection.handle, mtu) + # Compute the final MTU + if request.client_rx_mtu >= ATT_DEFAULT_MTU: + mtu = min(self.max_mtu, request.client_rx_mtu) + + # Notify the device + self.device.on_connection_att_mtu_update(connection.handle, mtu) + else: + logger.warning('invalid client_rx_mtu received, MTU not changed') def on_att_find_information_request(self, connection, request): ''' @@ -358,7 +360,7 @@ class Server(EventEmitter): return # Build list of returned attributes - pdu_space_available = self.get_mtu(connection) - 2 + pdu_space_available = connection.att_mtu - 2 attributes = [] uuid_size = 0 for attribute in ( @@ -409,7 +411,7 @@ class Server(EventEmitter): ''' # Build list of returned attributes - pdu_space_available = self.get_mtu(connection) - 2 + pdu_space_available = connection.att_mtu - 2 attributes = [] for attribute in ( attribute for attribute in self.attributes if @@ -457,8 +459,7 @@ class Server(EventEmitter): See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request ''' - mtu = self.get_mtu(connection) - pdu_space_available = mtu - 2 + pdu_space_available = connection.att_mtu - 2 attributes = [] for attribute in ( attribute for attribute in self.attributes if @@ -471,7 +472,7 @@ class Server(EventEmitter): # Check the attribute value size attribute_value = attribute.read_value(connection) - max_attribute_size = min(mtu - 4, 253) + max_attribute_size = min(connection.att_mtu - 4, 253) if len(attribute_value) > max_attribute_size: # We need to truncate attribute_value = attribute_value[:max_attribute_size] @@ -511,7 +512,7 @@ class Server(EventEmitter): if attribute := self.get_attribute(request.attribute_handle): # TODO: check permissions value = attribute.read_value(connection) - value_size = min(self.get_mtu(connection) - 1, len(value)) + value_size = min(connection.att_mtu - 1, len(value)) response = ATT_Read_Response( attribute_value = value[:value_size] ) @@ -530,7 +531,6 @@ class Server(EventEmitter): if attribute := self.get_attribute(request.attribute_handle): # TODO: check permissions - mtu = self.get_mtu(connection) value = attribute.read_value(connection) if request.value_offset > len(value): response = ATT_Error_Response( @@ -538,14 +538,14 @@ class Server(EventEmitter): attribute_handle_in_error = request.attribute_handle, error_code = ATT_INVALID_OFFSET_ERROR ) - elif len(value) <= mtu - 1: + elif len(value) <= connection.att_mtu - 1: response = ATT_Error_Response( request_opcode_in_error = request.op_code, attribute_handle_in_error = request.attribute_handle, error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR ) else: - part_size = min(mtu - 1, len(value) - request.value_offset) + part_size = min(connection.att_mtu - 1, len(value) - request.value_offset) response = ATT_Read_Blob_Response( part_attribute_value = value[request.value_offset:request.value_offset + part_size] ) @@ -574,8 +574,7 @@ class Server(EventEmitter): self.send_response(connection, response) return - mtu = self.get_mtu(connection) - pdu_space_available = mtu - 2 + pdu_space_available = connection.att_mtu - 2 attributes = [] for attribute in ( attribute for attribute in self.attributes if @@ -586,7 +585,7 @@ class Server(EventEmitter): ): # Check the attribute value size attribute_value = attribute.read_value(connection) - max_attribute_size = min(mtu - 6, 251) + max_attribute_size = min(connection.att_mtu - 6, 251) if len(attribute_value) > max_attribute_size: # We need to truncate attribute_value = attribute_value[:max_attribute_size] diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 9ff15191..a321c76e 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -131,7 +131,7 @@ async def test_characteristic_encoding(): def decode_value(self, value_bytes): return value_bytes[0] - [client, server] = TwoDevices().devices + [client, server] = LinkedDevices().devices[:2] characteristic = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -306,14 +306,15 @@ def test_CharacteristicValue(): # ----------------------------------------------------------------------------- -class TwoDevices: +class LinkedDevices: def __init__(self): - self.connections = [None, None] + self.connections = [None, None, None] self.link = LocalLink() self.controllers = [ Controller('C1', link = self.link), - Controller('C2', link = self.link) + Controller('C2', link = self.link), + Controller('C3', link = self.link) ] self.devices = [ Device( @@ -321,12 +322,16 @@ class TwoDevices: host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) ), Device( - address = 'F5:F4:F3:F2:F1:F0', + address = 'F1:F2:F3:F4:F5:F6', host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) + ), + Device( + address = 'F2:F3:F4:F5:F6:F7', + host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2])) ) ] - self.paired = [None, None] + self.paired = [None, None, None] # ----------------------------------------------------------------------------- @@ -339,7 +344,7 @@ async def async_barrier(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_write(): - [client, server] = TwoDevices().devices + [client, server] = LinkedDevices().devices[:2] characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -416,7 +421,7 @@ async def test_read_write(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_write2(): - [client, server] = TwoDevices().devices + [client, server] = LinkedDevices().devices[:2] v = bytes([0x11, 0x22, 0x33, 0x44]) characteristic1 = Characteristic( @@ -466,7 +471,7 @@ async def test_read_write2(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_subscribe_notify(): - [client, server] = TwoDevices().devices + [client, server] = LinkedDevices().devices[:2] characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -631,12 +636,49 @@ async def test_subscribe_notify(): assert not c3._called_2 +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_mtu_exchange(): + [d1, d2, d3] = LinkedDevices().devices[:3] + + d3.gatt_server.max_mtu = 100 + + d3_connections = [] + @d3.on('connection') + def on_d3_connection(connection): + d3_connections.append(connection) + + await d1.power_on() + await d2.power_on() + await d3.power_on() + + d1_connection = await d1.connect(d3.random_address) + assert len(d3_connections) == 1 + assert d3_connections[0] is not None + + d2_connection = await d2.connect(d3.random_address) + assert len(d3_connections) == 2 + assert d3_connections[1] is not None + + d1_peer = Peer(d1_connection) + d2_peer = Peer(d2_connection) + + d1_client_mtu = await d1_peer.request_mtu(220) + assert d1_client_mtu == 100 + assert d1_connection.att_mtu == 100 + + d2_client_mtu = await d2_peer.request_mtu(50) + assert d2_client_mtu == 50 + assert d2_connection.att_mtu == 50 + + # ----------------------------------------------------------------------------- async def async_main(): await test_read_write() await test_read_write2() await test_subscribe_notify() await test_characteristic_encoding() + await test_mtu_exchange() # -----------------------------------------------------------------------------