Merge pull request #75 from google/uael/fixes

Pairing: device/host fixes & improvements
This commit is contained in:
Lucas Abel
2022-11-28 21:42:43 -08:00
committed by GitHub
4 changed files with 98 additions and 47 deletions

View File

@@ -413,12 +413,36 @@ class Connection(CompositeEventEmitter):
self.parameters = parameters self.parameters = parameters
self.encryption = 0 self.encryption = 0
self.authenticated = False self.authenticated = False
self.sc = False
self.link_key_type = None
self.authenticating = False
self.phy = phy self.phy = phy
self.att_mtu = ATT_DEFAULT_MTU self.att_mtu = ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = None # Per-connection client self.gatt_client = None # Per-connection client
self.gatt_server = device.gatt_server # By default, use the device's shared server self.gatt_server = device.gatt_server # By default, use the device's shared server
# [Classic only]
@classmethod
def incomplete(cls, device, peer_address):
"""
Instantiate an incomplete connection (ie. one waiting for a HCI Connection Complete event).
Once received it shall be completed using the `.complete` method.
"""
return cls(device, None, BT_BR_EDR_TRANSPORT, device.public_address, peer_address, None, None, None, None)
# [Classic only]
def complete(self, handle, peer_resolvable_address, role, parameters):
"""
Finish an incomplete connection upon completion.
"""
assert self.handle is None
assert self.transport == BT_BR_EDR_TRANSPORT
self.handle = handle
self.peer_resolvable_address = peer_resolvable_address
self.role = role
self.parameters = parameters
@property @property
def role_name(self): def role_name(self):
return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL' return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL'
@@ -600,6 +624,8 @@ def with_connection_from_handle(function):
def with_connection_from_address(function): def with_connection_from_address(function):
@functools.wraps(function) @functools.wraps(function)
def wrapper(self, address, *args, **kwargs): def wrapper(self, address, *args, **kwargs):
if (connection := self.pending_connections.get(address, False)):
return function(self, connection, *args, **kwargs)
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address == address: if connection.peer_address == address:
return function(self, connection, *args, **kwargs) return function(self, connection, *args, **kwargs)
@@ -611,6 +637,8 @@ def with_connection_from_address(function):
def try_with_connection_from_address(function): def try_with_connection_from_address(function):
@functools.wraps(function) @functools.wraps(function)
def wrapper(self, address, *args, **kwargs): def wrapper(self, address, *args, **kwargs):
if (connection := self.pending_connections.get(address, False)):
return function(self, connection, address, *args, **kwargs)
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address == address: if connection.peer_address == address:
return function(self, connection, address, *args, **kwargs) return function(self, connection, address, *args, **kwargs)
@@ -698,6 +726,7 @@ class Device(CompositeEventEmitter):
self.le_connecting = False self.le_connecting = False
self.disconnecting = False self.disconnecting = False
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.pending_connections = {} # Connections, by BD address (BR/EDR only)
self.classic_enabled = False self.classic_enabled = False
self.inquiry_response = None self.inquiry_response = None
self.address_resolver = None self.address_resolver = None
@@ -818,7 +847,7 @@ class Device(CompositeEventEmitter):
def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False): def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False):
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address.get_bytes() == bd_addr.get_bytes(): if connection.peer_address.to_bytes() == bd_addr.to_bytes():
if check_address_type and connection.peer_address.address_type != bd_addr.address_type: if check_address_type and connection.peer_address.address_type != bd_addr.address_type:
continue continue
if transport is None or connection.transport == transport: if transport is None or connection.transport == transport:
@@ -1345,6 +1374,9 @@ class Device(CompositeEventEmitter):
max_ce_length = int(prefs.max_ce_length / 0.625), max_ce_length = int(prefs.max_ce_length / 0.625),
)) ))
else: else:
# Save pending connection
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
# TODO: allow passing other settings # TODO: allow passing other settings
result = await self.send_command(HCI_Create_Connection_Command( result = await self.send_command(HCI_Create_Connection_Command(
bd_addr = peer_address, bd_addr = peer_address,
@@ -1382,6 +1414,8 @@ class Device(CompositeEventEmitter):
if transport == BT_LE_TRANSPORT: if transport == BT_LE_TRANSPORT:
self.le_connecting = False self.le_connecting = False
self.connect_own_address_type = None self.connect_own_address_type = None
else:
self.pending_connections.pop(peer_address, None)
async def accept( async def accept(
self, self,
@@ -1395,7 +1429,7 @@ class Device(CompositeEventEmitter):
Notes: Notes:
* A `connect` to the same peer will also complete this call. * A `connect` to the same peer will also complete this call.
* The `timeout` parameter is only handled while waiting for the connection request, * The `timeout` parameter is only handled while waiting for the connection request,
once received and accepeted, the controller shall issue a connection complete event. once received and accepted, the controller shall issue a connection complete event.
''' '''
if type(peer_address) is str: if type(peer_address) is str:
@@ -1451,6 +1485,9 @@ class Device(CompositeEventEmitter):
self.on('connection', on_connection) self.on('connection', on_connection)
self.on('connection_failure', on_connection_failure) self.on('connection_failure', on_connection_failure)
# Save pending connection
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
try: try:
# Accept connection request # Accept connection request
await self.send_command(HCI_Accept_Connection_Request_Command( await self.send_command(HCI_Accept_Connection_Request_Command(
@@ -1464,6 +1501,7 @@ class Device(CompositeEventEmitter):
finally: finally:
self.remove_listener('connection', on_connection) self.remove_listener('connection', on_connection)
self.remove_listener('connection_failure', on_connection_failure) self.remove_listener('connection_failure', on_connection_failure)
self.pending_connections.pop(peer_address, None)
@asynccontextmanager @asynccontextmanager
async def connect_as_gatt(self, peer_address): async def connect_as_gatt(self, peer_address):
@@ -1707,9 +1745,13 @@ class Device(CompositeEventEmitter):
logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}') logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}')
raise HCI_StatusError(result) raise HCI_StatusError(result)
# Save in connection we are trying to authenticate
connection.authenticating = True
# Wait for the authentication to complete # Wait for the authentication to complete
await pending_authentication await pending_authentication
finally: finally:
connection.authenticating = False
connection.remove_listener('connection_authentication', on_authentication) connection.remove_listener('connection_authentication', on_authentication)
connection.remove_listener('connection_authentication_failure', on_authentication_failure) connection.remove_listener('connection_authentication_failure', on_authentication_failure)
@@ -1786,28 +1828,18 @@ class Device(CompositeEventEmitter):
# Set up event handlers # Set up event handlers
pending_name = asyncio.get_running_loop().create_future() pending_name = asyncio.get_running_loop().create_future()
if type(remote) == Address: peer_address = remote if type(remote) == Address else remote.peer_address
peer_address = remote
handler = self.on( handler = self.on(
'remote_name', 'remote_name',
lambda address, remote_name: lambda address, remote_name:
pending_name.set_result(remote_name) if address == remote else None pending_name.set_result(remote_name) if address == peer_address else None
) )
failure_handler = self.on( failure_handler = self.on(
'remote_name_failure', 'remote_name_failure',
lambda address, error_code: lambda address, error_code:
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None pending_name.set_exception(HCI_Error(error_code)) if address == peer_address else None
) )
else:
peer_address = remote.peer_address
handler = remote.on(
'remote_name',
lambda: pending_name.set_result(remote.peer_name)
)
failure_handler = remote.on(
'remote_name_failure',
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
)
try: try:
result = await self.send_command( result = await self.send_command(
@@ -1826,12 +1858,8 @@ class Device(CompositeEventEmitter):
# Wait for the result # Wait for the result
return await pending_name return await pending_name
finally: finally:
if type(remote) == Address: self.remove_listener('remote_name', handler)
self.remove_listener('remote_name', handler) self.remove_listener('remote_name_failure', failure_handler)
self.remove_listener('remote_name_failure', failure_handler)
else:
remote.remove_listener('remote_name', handler)
remote.remove_listener('remote_name_failure', failure_handler)
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
@@ -1849,6 +1877,9 @@ class Device(CompositeEventEmitter):
asyncio.create_task(store_keys()) asyncio.create_task(store_keys())
if (connection := self.find_connection_by_bd_addr(bd_addr, transport=BT_BR_EDR_TRANSPORT)):
connection.link_key_type = key_type
def add_service(self, service): def add_service(self, service):
self.gatt_server.add_service(service) self.gatt_server.add_service(service)
@@ -1875,17 +1906,8 @@ class Device(CompositeEventEmitter):
if transport == BT_BR_EDR_TRANSPORT: if transport == BT_BR_EDR_TRANSPORT:
# Create a new connection # Create a new connection
connection = Connection( connection: Connection = self.pending_connections.pop(peer_address)
self, connection.complete(connection_handle, peer_resolvable_address, role, connection_parameters)
connection_handle,
transport,
self.public_address,
peer_address,
peer_resolvable_address,
role,
connection_parameters,
phy=None
)
self.connections[connection_handle] = connection self.connections[connection_handle] = connection
# We may have an accept ongoing waiting for a connection request for `peer_address`. # We may have an accept ongoing waiting for a connection request for `peer_address`.
@@ -1991,6 +2013,9 @@ class Device(CompositeEventEmitter):
# device configuration is set to accept any incoming connection # device configuration is set to accept any incoming connection
elif self.classic_accept_any: elif self.classic_accept_any:
# Save pending connection
self.pending_connections[bd_addr] = Connection.incomplete(self, bd_addr)
self.host.send_command_sync( self.host.send_command_sync(
HCI_Accept_Connection_Request_Command( HCI_Accept_Connection_Request_Command(
bd_addr = bd_addr, bd_addr = bd_addr,
@@ -2064,6 +2089,17 @@ class Device(CompositeEventEmitter):
logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}')
connection.emit('connection_authentication_failure', error) connection.emit('connection_authentication_failure', error)
@host_event_handler
@with_connection_from_address
def on_ssp_complete(self, connection):
# On Secure Simple Pairing complete, in case:
# - Connection isn't already authenticated
# - AND We are not the initiator of the authentication
# We must trigger authentication to known if we are truly authenticated
if not connection.authenticating and not connection.authenticated:
logger.debug(f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}')
asyncio.create_task(connection.authenticate())
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
@with_connection_from_address @with_connection_from_address
@@ -2200,8 +2236,7 @@ class Device(CompositeEventEmitter):
if connection: if connection:
connection.peer_name = remote_name connection.peer_name = remote_name
connection.emit('remote_name') connection.emit('remote_name')
else: self.emit('remote_name', address, remote_name)
self.emit('remote_name', address, remote_name)
except UnicodeDecodeError as error: except UnicodeDecodeError as error:
logger.warning('peer name is not valid UTF-8') logger.warning('peer name is not valid UTF-8')
if connection: if connection:
@@ -2215,8 +2250,7 @@ class Device(CompositeEventEmitter):
def on_remote_name_failure(self, connection, address, error): def on_remote_name_failure(self, connection, address, error):
if connection: if connection:
connection.emit('remote_name_failure', error) connection.emit('remote_name_failure', error)
else: self.emit('remote_name_failure', address, error)
self.emit('remote_name_failure', address, error)
@host_event_handler @host_event_handler
@with_connection_from_handle @with_connection_from_handle
@@ -2282,7 +2316,9 @@ class Device(CompositeEventEmitter):
connection.emit('pairing_start') connection.emit('pairing_start')
@with_connection_from_handle @with_connection_from_handle
def on_pairing(self, connection, keys): def on_pairing(self, connection, keys, sc):
connection.sc = sc
connection.authenticated = True
connection.emit('pairing', keys) connection.emit('pairing', keys)
@with_connection_from_handle @with_connection_from_handle

View File

@@ -599,6 +599,9 @@ class Host(EventEmitter):
def on_hci_simple_pairing_complete_event(self, event): def on_hci_simple_pairing_complete_event(self, event):
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('ssp_complete', event.bd_addr)
def on_hci_pin_code_request_event(self, event): def on_hci_pin_code_request_event(self, event):
# For now, just refuse all requests # For now, just refuse all requests

View File

@@ -20,6 +20,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import os import os
import json import json
@@ -143,6 +144,10 @@ class KeyStore:
async def get_all(self): async def get_all(self):
return [] return []
async def delete_all(self):
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
@@ -259,6 +264,13 @@ class JsonKeyStore(KeyStore):
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()] return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
await self.save(db)
async def get(self, name): async def get(self, name):
db = await self.load() db = await self.load()

View File

@@ -1583,7 +1583,7 @@ class Manager(EventEmitter):
asyncio.create_task(store_keys()) asyncio.create_task(store_keys())
# Notify the device # Notify the device
self.device.on_pairing(session.connection.handle, keys) self.device.on_pairing(session.connection.handle, keys, session.sc)
def on_pairing_failure(self, session, reason): def on_pairing_failure(self, session, reason):
self.device.on_pairing_failure(session.connection.handle, reason) self.device.on_pairing_failure(session.connection.handle, reason)