forked from auracaster/bumble_mirror
Merge pull request #75 from google/uael/fixes
Pairing: device/host fixes & improvements
This commit is contained in:
128
bumble/device.py
128
bumble/device.py
@@ -413,12 +413,36 @@ class Connection(CompositeEventEmitter):
|
||||
self.parameters = parameters
|
||||
self.encryption = 0
|
||||
self.authenticated = False
|
||||
self.sc = False
|
||||
self.link_key_type = None
|
||||
self.authenticating = False
|
||||
self.phy = phy
|
||||
self.att_mtu = ATT_DEFAULT_MTU
|
||||
self.data_length = DEVICE_DEFAULT_DATA_LENGTH
|
||||
self.gatt_client = None # Per-connection client
|
||||
self.gatt_server = device.gatt_server # By default, use the device's shared server
|
||||
|
||||
# [Classic only]
|
||||
@classmethod
|
||||
def incomplete(cls, device, peer_address):
|
||||
"""
|
||||
Instantiate an incomplete connection (ie. one waiting for a HCI Connection Complete event).
|
||||
Once received it shall be completed using the `.complete` method.
|
||||
"""
|
||||
return cls(device, None, BT_BR_EDR_TRANSPORT, device.public_address, peer_address, None, None, None, None)
|
||||
|
||||
# [Classic only]
|
||||
def complete(self, handle, peer_resolvable_address, role, parameters):
|
||||
"""
|
||||
Finish an incomplete connection upon completion.
|
||||
"""
|
||||
assert self.handle is None
|
||||
assert self.transport == BT_BR_EDR_TRANSPORT
|
||||
self.handle = handle
|
||||
self.peer_resolvable_address = peer_resolvable_address
|
||||
self.role = role
|
||||
self.parameters = parameters
|
||||
|
||||
@property
|
||||
def role_name(self):
|
||||
return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL'
|
||||
@@ -600,6 +624,8 @@ def with_connection_from_handle(function):
|
||||
def with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
if (connection := self.pending_connections.get(address, False)):
|
||||
return function(self, connection, *args, **kwargs)
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, *args, **kwargs)
|
||||
@@ -611,6 +637,8 @@ def with_connection_from_address(function):
|
||||
def try_with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
if (connection := self.pending_connections.get(address, False)):
|
||||
return function(self, connection, address, *args, **kwargs)
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, address, *args, **kwargs)
|
||||
@@ -698,6 +726,7 @@ class Device(CompositeEventEmitter):
|
||||
self.le_connecting = False
|
||||
self.disconnecting = False
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.pending_connections = {} # Connections, by BD address (BR/EDR only)
|
||||
self.classic_enabled = False
|
||||
self.inquiry_response = None
|
||||
self.address_resolver = None
|
||||
@@ -818,7 +847,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False):
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address.get_bytes() == bd_addr.get_bytes():
|
||||
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
|
||||
if check_address_type and connection.peer_address.address_type != bd_addr.address_type:
|
||||
continue
|
||||
if transport is None or connection.transport == transport:
|
||||
@@ -1345,6 +1374,9 @@ class Device(CompositeEventEmitter):
|
||||
max_ce_length = int(prefs.max_ce_length / 0.625),
|
||||
))
|
||||
else:
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
|
||||
|
||||
# TODO: allow passing other settings
|
||||
result = await self.send_command(HCI_Create_Connection_Command(
|
||||
bd_addr = peer_address,
|
||||
@@ -1382,6 +1414,8 @@ class Device(CompositeEventEmitter):
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.le_connecting = False
|
||||
self.connect_own_address_type = None
|
||||
else:
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
@@ -1395,7 +1429,7 @@ class Device(CompositeEventEmitter):
|
||||
Notes:
|
||||
* A `connect` to the same peer will also complete this call.
|
||||
* The `timeout` parameter is only handled while waiting for the connection request,
|
||||
once received and accepeted, the controller shall issue a connection complete event.
|
||||
once received and accepted, the controller shall issue a connection complete event.
|
||||
'''
|
||||
|
||||
if type(peer_address) is str:
|
||||
@@ -1451,6 +1485,9 @@ class Device(CompositeEventEmitter):
|
||||
self.on('connection', on_connection)
|
||||
self.on('connection_failure', on_connection_failure)
|
||||
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
|
||||
|
||||
try:
|
||||
# Accept connection request
|
||||
await self.send_command(HCI_Accept_Connection_Request_Command(
|
||||
@@ -1464,6 +1501,7 @@ class Device(CompositeEventEmitter):
|
||||
finally:
|
||||
self.remove_listener('connection', on_connection)
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_as_gatt(self, peer_address):
|
||||
@@ -1707,9 +1745,13 @@ class Device(CompositeEventEmitter):
|
||||
logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}')
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
# Save in connection we are trying to authenticate
|
||||
connection.authenticating = True
|
||||
|
||||
# Wait for the authentication to complete
|
||||
await pending_authentication
|
||||
finally:
|
||||
connection.authenticating = False
|
||||
connection.remove_listener('connection_authentication', on_authentication)
|
||||
connection.remove_listener('connection_authentication_failure', on_authentication_failure)
|
||||
|
||||
@@ -1786,28 +1828,18 @@ class Device(CompositeEventEmitter):
|
||||
# Set up event handlers
|
||||
pending_name = asyncio.get_running_loop().create_future()
|
||||
|
||||
if type(remote) == Address:
|
||||
peer_address = remote
|
||||
handler = self.on(
|
||||
'remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == remote else None
|
||||
)
|
||||
failure_handler = self.on(
|
||||
'remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
|
||||
)
|
||||
else:
|
||||
peer_address = remote.peer_address
|
||||
handler = remote.on(
|
||||
'remote_name',
|
||||
lambda: pending_name.set_result(remote.peer_name)
|
||||
)
|
||||
failure_handler = remote.on(
|
||||
'remote_name_failure',
|
||||
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
|
||||
)
|
||||
peer_address = remote if type(remote) == Address else remote.peer_address
|
||||
|
||||
handler = self.on(
|
||||
'remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == peer_address else None
|
||||
)
|
||||
failure_handler = self.on(
|
||||
'remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == peer_address else None
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.send_command(
|
||||
@@ -1826,12 +1858,8 @@ class Device(CompositeEventEmitter):
|
||||
# Wait for the result
|
||||
return await pending_name
|
||||
finally:
|
||||
if type(remote) == Address:
|
||||
self.remove_listener('remote_name', handler)
|
||||
self.remove_listener('remote_name_failure', failure_handler)
|
||||
else:
|
||||
remote.remove_listener('remote_name', handler)
|
||||
remote.remove_listener('remote_name_failure', failure_handler)
|
||||
self.remove_listener('remote_name', handler)
|
||||
self.remove_listener('remote_name_failure', failure_handler)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -1849,6 +1877,9 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
asyncio.create_task(store_keys())
|
||||
|
||||
if (connection := self.find_connection_by_bd_addr(bd_addr, transport=BT_BR_EDR_TRANSPORT)):
|
||||
connection.link_key_type = key_type
|
||||
|
||||
def add_service(self, service):
|
||||
self.gatt_server.add_service(service)
|
||||
|
||||
@@ -1875,17 +1906,8 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
if transport == BT_BR_EDR_TRANSPORT:
|
||||
# Create a new connection
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
transport,
|
||||
self.public_address,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
connection_parameters,
|
||||
phy=None
|
||||
)
|
||||
connection: Connection = self.pending_connections.pop(peer_address)
|
||||
connection.complete(connection_handle, peer_resolvable_address, role, connection_parameters)
|
||||
self.connections[connection_handle] = connection
|
||||
|
||||
# We may have an accept ongoing waiting for a connection request for `peer_address`.
|
||||
@@ -1991,6 +2013,9 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# device configuration is set to accept any incoming connection
|
||||
elif self.classic_accept_any:
|
||||
# Save pending connection
|
||||
self.pending_connections[bd_addr] = Connection.incomplete(self, bd_addr)
|
||||
|
||||
self.host.send_command_sync(
|
||||
HCI_Accept_Connection_Request_Command(
|
||||
bd_addr = bd_addr,
|
||||
@@ -2064,6 +2089,17 @@ class Device(CompositeEventEmitter):
|
||||
logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}')
|
||||
connection.emit('connection_authentication_failure', error)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_ssp_complete(self, connection):
|
||||
# On Secure Simple Pairing complete, in case:
|
||||
# - Connection isn't already authenticated
|
||||
# - AND We are not the initiator of the authentication
|
||||
# We must trigger authentication to known if we are truly authenticated
|
||||
if not connection.authenticating and not connection.authenticated:
|
||||
logger.debug(f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}')
|
||||
asyncio.create_task(connection.authenticate())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
@@ -2200,8 +2236,7 @@ class Device(CompositeEventEmitter):
|
||||
if connection:
|
||||
connection.peer_name = remote_name
|
||||
connection.emit('remote_name')
|
||||
else:
|
||||
self.emit('remote_name', address, remote_name)
|
||||
self.emit('remote_name', address, remote_name)
|
||||
except UnicodeDecodeError as error:
|
||||
logger.warning('peer name is not valid UTF-8')
|
||||
if connection:
|
||||
@@ -2215,8 +2250,7 @@ class Device(CompositeEventEmitter):
|
||||
def on_remote_name_failure(self, connection, address, error):
|
||||
if connection:
|
||||
connection.emit('remote_name_failure', error)
|
||||
else:
|
||||
self.emit('remote_name_failure', address, error)
|
||||
self.emit('remote_name_failure', address, error)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
@@ -2282,7 +2316,9 @@ class Device(CompositeEventEmitter):
|
||||
connection.emit('pairing_start')
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing(self, connection, keys):
|
||||
def on_pairing(self, connection, keys, sc):
|
||||
connection.sc = sc
|
||||
connection.authenticated = True
|
||||
connection.emit('pairing', keys)
|
||||
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -599,6 +599,9 @@ class Host(EventEmitter):
|
||||
|
||||
def on_hci_simple_pairing_complete_event(self, event):
|
||||
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
|
||||
# Notify the client
|
||||
if event.status == HCI_SUCCESS:
|
||||
self.emit('ssp_complete', event.bd_addr)
|
||||
|
||||
def on_hci_pin_code_request_event(self, event):
|
||||
# For now, just refuse all requests
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
@@ -143,6 +144,10 @@ class KeyStore:
|
||||
async def get_all(self):
|
||||
return []
|
||||
|
||||
async def delete_all(self):
|
||||
all_keys = await self.get_all()
|
||||
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
|
||||
|
||||
async def get_resolving_keys(self):
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
@@ -259,6 +264,13 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
|
||||
|
||||
async def delete_all(self):
|
||||
db = await self.load()
|
||||
|
||||
db.pop(self.namespace, None)
|
||||
|
||||
await self.save(db)
|
||||
|
||||
async def get(self, name):
|
||||
db = await self.load()
|
||||
|
||||
|
||||
@@ -1583,7 +1583,7 @@ class Manager(EventEmitter):
|
||||
asyncio.create_task(store_keys())
|
||||
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection.handle, keys)
|
||||
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session, reason):
|
||||
self.device.on_pairing_failure(session.connection.handle, reason)
|
||||
|
||||
Reference in New Issue
Block a user