Merge pull request #68 from google/uael/pairing-improvements

Pairing improvements
This commit is contained in:
Lucas Abel
2022-11-11 21:03:17 -08:00
committed by GitHub
4 changed files with 135 additions and 39 deletions

View File

@@ -394,6 +394,7 @@ class Connection(CompositeEventEmitter):
device,
handle,
transport,
self_address,
peer_address,
peer_resolvable_address,
role,
@@ -404,6 +405,7 @@ class Connection(CompositeEventEmitter):
self.device = device
self.handle = handle
self.transport = transport
self.self_address = self_address
self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only
@@ -699,6 +701,10 @@ class Device(CompositeEventEmitter):
self.address_resolver = None
self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY
# Own address type cache
self.advertising_own_address_type = None
self.connect_own_address_type = None
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
if config is None:
@@ -731,8 +737,7 @@ class Device(CompositeEventEmitter):
self.random_address = address
# Setup SMP
# TODO: allow using a public address
self.smp_manager = smp.Manager(self, self.random_address)
self.smp_manager = smp.Manager(self)
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_CID, self.on_smp_pdu)
self.l2cap_channel_manager.register_fixed_channel(
@@ -928,7 +933,7 @@ class Device(CompositeEventEmitter):
self,
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target=None,
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
own_address_type=OwnAddressType.RANDOM,
auto_restart=False
):
# If we're advertising, stop first
@@ -975,9 +980,10 @@ class Device(CompositeEventEmitter):
advertising_enable = 1
), check_result=True)
self.auto_restart_advertising = auto_restart
self.advertising_type = advertising_type
self.advertising = True
self.advertising_own_address_type = own_address_type
self.auto_restart_advertising = auto_restart
self.advertising_type = advertising_type
self.advertising = True
async def stop_advertising(self):
# Disable advertising
@@ -986,9 +992,10 @@ class Device(CompositeEventEmitter):
advertising_enable = 0
), check_result=True)
self.advertising = False
self.advertising_type = None
self.auto_restart_advertising = False
self.advertising_own_address_type = None
self.advertising = False
self.advertising_type = None
self.auto_restart_advertising = False
@property
def is_advertising(self):
@@ -1000,7 +1007,7 @@ class Device(CompositeEventEmitter):
active=True,
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
own_address_type=OwnAddressType.RANDOM,
filter_duplicates=False,
scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY)
):
@@ -1181,7 +1188,7 @@ class Device(CompositeEventEmitter):
peer_address,
transport=BT_LE_TRANSPORT,
connection_parameters_preferences=None,
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
own_address_type=OwnAddressType.RANDOM,
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT
):
'''
@@ -1251,6 +1258,8 @@ class Device(CompositeEventEmitter):
HCI_LE_CODED_PHY: ConnectionParametersPreferences.default
}
self.connect_own_address_type = own_address_type
if self.host.supports_command(HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND):
# Only keep supported PHYs
phys = sorted(list(set(filter(self.supports_le_phy, connection_parameters_preferences.keys()))))
@@ -1350,6 +1359,7 @@ class Device(CompositeEventEmitter):
self.remove_listener('connection_failure', on_connection_failure)
if transport == BT_LE_TRANSPORT:
self.le_connecting = False
self.connect_own_address_type = None
async def accept(
self,
@@ -1847,6 +1857,7 @@ class Device(CompositeEventEmitter):
self,
connection_handle,
transport,
self.public_address,
peer_address,
peer_resolvable_address,
role,
@@ -1875,8 +1886,17 @@ class Device(CompositeEventEmitter):
peer_resolvable_address = peer_address
peer_address = resolved_address
# Guess which own address type is used for this connection.
# This logic is somewhat correct but may need to be improved
# when multiple advertising are run simultaneously.
if self.connect_own_address_type is not None:
own_address_type = self.connect_own_address_type
else:
own_address_type = self.advertising_own_address_type
# We are no longer advertising
self.advertising = False
self.advertising_own_address_type = None
self.advertising = False
# Create and notify of the new connection asynchronously
async def new_connection():
@@ -1890,11 +1910,16 @@ class Device(CompositeEventEmitter):
else:
phy = ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY)
self_address = self.random_address
if own_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
self_address = self.public_address
# Create a new connection
connection = Connection(
self,
connection_handle,
transport,
self_address,
peer_address,
peer_resolvable_address,
role,
@@ -1914,7 +1939,8 @@ class Device(CompositeEventEmitter):
# For directed advertising, this means a timeout
if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed:
self.advertising = False
self.advertising_own_address_type = None
self.advertising = False
# Notify listeners
error = ConnectionError(
@@ -2067,13 +2093,13 @@ class Device(CompositeEventEmitter):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
can_confirm = pairing_config.delegate.io_capability not in {
can_compare = pairing_config.delegate.io_capability not in {
smp.SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
smp.SMP_DISPLAY_ONLY_IO_CAPABILITY
}
# Respond
if can_confirm and pairing_config.delegate:
if can_compare:
async def compare_numbers():
numbers_match = await pairing_config.delegate.compare_numbers(code, digits=6)
if numbers_match:
@@ -2087,9 +2113,18 @@ class Device(CompositeEventEmitter):
asyncio.create_task(compare_numbers())
else:
self.host.send_command_sync(
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
)
async def confirm():
confirm = await pairing_config.delegate.confirm()
if confirm:
self.host.send_command_sync(
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
)
else:
self.host.send_command_sync(
HCI_User_Confirmation_Request_Negative_Reply_Command(bd_addr=connection.peer_address)
)
asyncio.create_task(confirm())
# [Classic only]
@host_event_handler
@@ -2104,7 +2139,7 @@ class Device(CompositeEventEmitter):
}
# Respond
if can_input and pairing_config.delegate:
if can_input:
async def get_number():
number = await pairing_config.delegate.get_number()
if number is not None:
@@ -2124,6 +2159,15 @@ class Device(CompositeEventEmitter):
HCI_User_Passkey_Request_Negative_Reply_Command(bd_addr=connection.peer_address)
)
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_notification(self, connection, passkey):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
asyncio.create_task(pairing_config.delegate.display_number(passkey))
# [Classic only]
@host_event_handler
@try_with_connection_from_address

View File

@@ -1756,6 +1756,26 @@ class Address:
return ':'.join([f'{x:02X}' for x in reversed(self.address_bytes)])
# -----------------------------------------------------------------------------
class OwnAddressType:
PUBLIC = 0
RANDOM = 1
RESOLVABLE_OR_PUBLIC = 2
RESOLVABLE_OR_RANDOM = 3
TYPE_NAMES = {
PUBLIC: 'PUBLIC',
RANDOM: 'RANDOM',
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM'
}
@staticmethod
def type_name(type):
return name_or_number(OwnAddressType.TYPE_NAMES, type)
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
# -----------------------------------------------------------------------------
class HCI_Packet:
'''
@@ -2848,7 +2868,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
('advertising_interval_min', 2),
('advertising_interval_max', 2),
('advertising_type', {'size': 1, 'mapper': lambda x: HCI_LE_Set_Advertising_Parameters_Command.advertising_type_name(x)}),
('own_address_type', Address.ADDRESS_TYPE_SPEC),
('own_address_type', OwnAddressType.TYPE_SPEC),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('advertising_channel_map', 1),
@@ -2927,7 +2947,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
('le_scan_type', 1),
('le_scan_interval', 2),
('le_scan_window', 2),
('own_address_type', Address.ADDRESS_TYPE_SPEC),
('own_address_type', OwnAddressType.TYPE_SPEC),
('scanning_filter_policy', 1)
])
class HCI_LE_Set_Scan_Parameters_Command(HCI_Command):
@@ -2961,7 +2981,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
('initiator_filter_policy', 1),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('own_address_type', Address.ADDRESS_TYPE_SPEC),
('own_address_type', OwnAddressType.TYPE_SPEC),
('connection_interval_min', 2),
('connection_interval_max', 2),
('max_latency', 2),
@@ -3283,7 +3303,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
('primary_advertising_interval_min', 3),
('primary_advertising_interval_max', 3),
('primary_advertising_channel_map', {'size': 1, 'mapper': lambda x: HCI_LE_Set_Extended_Advertising_Parameters_Command.channel_map_string(x)}),
('own_address_type', Address.ADDRESS_TYPE_SPEC),
('own_address_type', OwnAddressType.TYPE_SPEC),
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type),
('advertising_filter_policy', 1),
@@ -3687,7 +3707,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
initiating_phys_strs = bit_flags_to_strings(self.initiating_phys, HCI_LE_PHY_BIT_NAMES)
fields = [
('initiator_filter_policy:', self.initiator_filter_policy),
('own_address_type: ', Address.address_type_name(self.own_address_type)),
('own_address_type: ', OwnAddressType.type_name(self.own_address_type)),
('peer_address_type: ', Address.address_type_name(self.peer_address_type)),
('peer_address: ', str(self.peer_address)),
('initiating_phys: ', ','.join(initiating_phys_strs)),
@@ -4855,6 +4875,17 @@ class HCI_Link_Supervision_Timeout_Changed_Event(HCI_Event):
'''
# -----------------------------------------------------------------------------
@HCI_Event.event([
('bd_addr', Address.parse_address),
('passkey', 4)
])
class HCI_User_Passkey_Notification_Event(HCI_Event):
'''
See Bluetooth spec @ 7.7.48 User Passkey Notification Event
'''
# -----------------------------------------------------------------------------
@HCI_Event.event([
('bd_addr', Address.parse_address),

View File

@@ -638,6 +638,9 @@ class Host(EventEmitter):
def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
self.emit('authentication_user_passkey_notification', event.bd_addr, event.passkey)
def on_hci_inquiry_complete_event(self, event):
self.emit('inquiry_complete')

View File

@@ -477,6 +477,9 @@ class PairingDelegate:
async def accept(self):
return True
async def confirm(self):
return True
async def compare_numbers(self, number, digits=6):
return True
@@ -637,15 +640,16 @@ class Session:
self.oob = False
# Set up addresses
self_address = connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address
if self.is_initiator:
self.ia = bytes(manager.address)
self.iat = 1 if manager.address.is_random else 0
self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0
self.ra = bytes(peer_address)
self.rat = 1 if peer_address.is_random else 0
else:
self.ra = bytes(manager.address)
self.rat = 1 if manager.address.is_random else 0
self.ra = bytes(self_address)
self.rat = 1 if self_address.is_random else 0
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
@@ -715,6 +719,21 @@ class Session:
return False
return True
def prompt_user_for_confirmation(self, next_steps):
async def prompt():
logger.debug('ask for confirmation')
try:
response = await self.pairing_config.delegate.confirm()
if response:
next_steps()
return
except Exception as error:
logger.warn(f'exception while confirm: {error}')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
asyncio.create_task(prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt():
logger.debug(f'verification code: {code}')
@@ -907,8 +926,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
addr_type = self.connection.self_address.address_type,
bd_addr = self.connection.self_address
))
# Distribute CSRK
@@ -939,8 +958,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
addr_type = self.connection.self_address.address_type,
bd_addr = self.connection.self_address
))
# Distribute CSRK
@@ -1387,12 +1406,12 @@ class Session:
# Compute the 6-digit code
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
if self.pairing_method == self.NUMERIC_COMPARISON:
# Ask for user confirmation
self.wait_before_continuing = asyncio.get_running_loop().create_future()
self.prompt_user_for_numeric_comparison(code, next_steps)
# Ask for user confirmation
self.wait_before_continuing = asyncio.get_running_loop().create_future()
if self.pairing_method == self.JUST_WORKS:
self.prompt_user_for_confirmation(next_steps)
else:
next_steps()
self.prompt_user_for_numeric_comparison(code, next_steps)
else:
next_steps()
@@ -1486,10 +1505,9 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol
'''
def __init__(self, device, address):
def __init__(self, device):
super().__init__()
self.device = device
self.address = address
self.sessions = {}
self._ecc_key = None
self.pairing_config_factory = lambda connection: PairingConfig()