fix role state for classic connections

This commit is contained in:
Gilles Boccon-Gibod
2023-04-07 10:24:26 -07:00
parent c53e1d2480
commit 859aea5a99
8 changed files with 57 additions and 40 deletions

View File

@@ -264,6 +264,7 @@ async def pair(
sc, sc,
mitm, mitm,
bond, bond,
ctkd,
io, io,
prompt, prompt,
request, request,
@@ -317,6 +318,7 @@ async def pair(
if mode == 'classic': if mode == 'classic':
device.classic_enabled = True device.classic_enabled = True
device.le_enabled = False device.le_enabled = False
device.classic_smp_enabled = ctkd
# Get things going # Get things going
await device.power_on() await device.power_on()
@@ -379,6 +381,13 @@ class LogHandler(logging.Handler):
@click.option( @click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True '--bond', type=bool, default=True, help='Enable bonding', show_default=True
) )
@click.option(
'--ctkd',
type=bool,
default=True,
help='Enable CTKD',
show_default=True,
)
@click.option( @click.option(
'--io', '--io',
type=click.Choice( type=click.Choice(
@@ -405,6 +414,7 @@ def main(
sc, sc,
mitm, mitm,
bond, bond,
ctkd,
io, io,
prompt, prompt,
request, request,
@@ -427,6 +437,7 @@ def main(
sc, sc,
mitm, mitm,
bond, bond,
ctkd,
io, io,
prompt, prompt,
request, request,

View File

@@ -595,7 +595,7 @@ class Connection(CompositeEventEmitter):
# [Classic only] # [Classic only]
@classmethod @classmethod
def incomplete(cls, device, peer_address): def incomplete(cls, device, peer_address, role):
""" """
Instantiate an incomplete connection (ie. one waiting for a HCI Connection Instantiate an incomplete connection (ie. one waiting for a HCI Connection
Complete event). Complete event).
@@ -608,28 +608,30 @@ class Connection(CompositeEventEmitter):
device.public_address, device.public_address,
peer_address, peer_address,
None, None,
None, role,
None, None,
None, None,
) )
# [Classic only] # [Classic only]
def complete(self, handle, peer_resolvable_address, role, parameters): def complete(self, handle, parameters):
""" """
Finish an incomplete connection upon completion. Finish an incomplete connection upon completion.
""" """
assert self.handle is None assert self.handle is None
assert self.transport == BT_BR_EDR_TRANSPORT assert self.transport == BT_BR_EDR_TRANSPORT
self.handle = handle self.handle = handle
self.peer_resolvable_address = peer_resolvable_address
# Quirk: role might be known before complete
if self.role is None:
self.role = role
self.parameters = parameters self.parameters = parameters
@property @property
def role_name(self): def role_name(self):
return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL' if self.role is None:
return 'NOT-SET'
if self.role == BT_CENTRAL_ROLE:
return 'CENTRAL'
if self.role == BT_PERIPHERAL_ROLE:
return 'PERIPHERAL'
return f'UNKNOWN[{self.role}]'
@property @property
def is_encrypted(self): def is_encrypted(self):
@@ -637,7 +639,7 @@ class Connection(CompositeEventEmitter):
@property @property
def is_incomplete(self) -> bool: def is_incomplete(self) -> bool:
return self.handle == None return self.handle is None
def send_l2cap_pdu(self, cid, pdu): def send_l2cap_pdu(self, cid, pdu):
self.device.send_l2cap_pdu(self.handle, cid, pdu) self.device.send_l2cap_pdu(self.handle, cid, pdu)
@@ -750,10 +752,11 @@ class DeviceConfiguration:
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.le_enabled = True self.le_enabled = True
# LE host enable 2nd parameter # LE host enable 2nd parameter
self.le_simultaneous_enabled = True self.le_simultaneous_enabled = False
self.classic_enabled = False self.classic_enabled = False
self.classic_sc_enabled = True self.classic_sc_enabled = True
self.classic_ssp_enabled = True self.classic_ssp_enabled = True
self.classic_smp_enabled = True
self.classic_accept_any = True self.classic_accept_any = True
self.connectable = True self.connectable = True
self.discoverable = True self.discoverable = True
@@ -788,6 +791,9 @@ class DeviceConfiguration:
self.classic_ssp_enabled = config.get( self.classic_ssp_enabled = config.get(
'classic_ssp_enabled', self.classic_ssp_enabled 'classic_ssp_enabled', self.classic_ssp_enabled
) )
self.classic_smp_enabled = config.get(
'classic_smp_enabled', self.classic_smp_enabled
)
self.classic_accept_any = config.get( self.classic_accept_any = config.get(
'classic_accept_any', self.classic_accept_any 'classic_accept_any', self.classic_accept_any
) )
@@ -997,8 +1003,9 @@ class Device(CompositeEventEmitter):
self.le_enabled = config.le_enabled self.le_enabled = config.le_enabled
self.classic_enabled = config.classic_enabled self.classic_enabled = config.classic_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_sc_enabled = config.classic_sc_enabled self.classic_sc_enabled = config.classic_sc_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_smp_enabled = config.classic_smp_enabled
self.discoverable = config.discoverable self.discoverable = config.discoverable
self.connectable = config.connectable self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any self.classic_accept_any = config.classic_accept_any
@@ -1043,9 +1050,6 @@ class Device(CompositeEventEmitter):
# Setup SMP # Setup SMP
self.smp_manager = smp.Manager(self) self.smp_manager = smp.Manager(self)
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu) self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_BR_CID, self.on_smp_pdu
)
# Register the SDP server with the L2CAP Channel Manager # Register the SDP server with the L2CAP Channel Manager
self.sdp_server.register(self.l2cap_channel_manager) self.sdp_server.register(self.l2cap_channel_manager)
@@ -1182,6 +1186,12 @@ class Device(CompositeEventEmitter):
if self.keystore is None: if self.keystore is None:
self.keystore = KeyStore.create_for_device(self) self.keystore = KeyStore.create_for_device(self)
# Finish setting up SMP based on post-init configurable options
if self.classic_smp_enabled:
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_BR_CID, self.on_smp_pdu
)
if self.host.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND): if self.host.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND):
await self.send_command( await self.send_command(
HCI_Write_LE_Host_Support_Command( HCI_Write_LE_Host_Support_Command(
@@ -1803,7 +1813,7 @@ class Device(CompositeEventEmitter):
else: else:
# Save pending connection # Save pending connection
self.pending_connections[peer_address] = Connection.incomplete( self.pending_connections[peer_address] = Connection.incomplete(
self, peer_address self, peer_address, BT_CENTRAL_ROLE
) )
# TODO: allow passing other settings # TODO: allow passing other settings
@@ -1942,7 +1952,7 @@ class Device(CompositeEventEmitter):
# Save pending connection # Save pending connection
self.pending_connections[peer_address] = Connection.incomplete( self.pending_connections[peer_address] = Connection.incomplete(
self, peer_address self, peer_address, BT_PERIPHERAL_ROLE
) )
try: try:
@@ -2215,6 +2225,9 @@ class Device(CompositeEventEmitter):
keys = await self.keystore.get(str(address)) keys = await self.keystore.get(str(address))
if keys is not None: if keys is not None:
logger.debug('found keys in the key store') logger.debug('found keys in the key store')
if keys.link_key is None:
logger.debug('no link key')
return None
return keys.link_key.value return keys.link_key.value
# [Classic only] # [Classic only]
@@ -2464,25 +2477,24 @@ class Device(CompositeEventEmitter):
connection_handle, connection_handle,
transport, transport,
peer_address, peer_address,
peer_resolvable_address,
role, role,
connection_parameters, connection_parameters,
): ):
logger.debug( logger.debug(
f'*** Connection: [0x{connection_handle:04X}] ' f'*** Connection: [0x{connection_handle:04X}] '
f'{peer_address} as {HCI_Constant.role_name(role)}' f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}'
) )
if connection_handle in self.connections: if connection_handle in self.connections:
logger.warning( logger.warning(
'new connection reuses the same handle as a previous connection' 'new connection reuses the same handle as a previous connection'
) )
peer_resolvable_address = None
if transport == BT_BR_EDR_TRANSPORT: if transport == BT_BR_EDR_TRANSPORT:
# Create a new connection # Create a new connection
connection = self.pending_connections.pop(peer_address) connection = self.pending_connections.pop(peer_address)
connection.complete( connection.complete(connection_handle, connection_parameters)
connection_handle, peer_resolvable_address, role, connection_parameters
)
self.connections[connection_handle] = connection self.connections[connection_handle] = connection
# Emit an event to notify listeners of the new connection # Emit an event to notify listeners of the new connection
@@ -2594,7 +2606,9 @@ class Device(CompositeEventEmitter):
# device configuration is set to accept any incoming connection # device configuration is set to accept any incoming connection
elif self.classic_accept_any: elif self.classic_accept_any:
# Save pending connection # Save pending connection
self.pending_connections[bd_addr] = Connection.incomplete(self, bd_addr) self.pending_connections[bd_addr] = Connection.incomplete(
self, bd_addr, BT_PERIPHERAL_ROLE
)
self.host.send_command_sync( self.host.send_command_sync(
HCI_Accept_Connection_Request_Command( HCI_Accept_Connection_Request_Command(

View File

@@ -94,10 +94,9 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection: class Connection:
def __init__(self, host, handle, role, peer_address, transport): def __init__(self, host, handle, peer_address, transport):
self.host = host self.host = host
self.handle = handle self.handle = handle
self.role = role
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
@@ -534,7 +533,7 @@ class Host(AbortableEventEmitter):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
# Create/update the connection # Create/update the connection
logger.debug( logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] ' f'### LE CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.peer_address} as {HCI_Constant.role_name(event.role)}' f'{event.peer_address} as {HCI_Constant.role_name(event.role)}'
) )
@@ -543,7 +542,6 @@ class Host(AbortableEventEmitter):
connection = Connection( connection = Connection(
self, self,
event.connection_handle, event.connection_handle,
event.role,
event.peer_address, event.peer_address,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
) )
@@ -560,7 +558,6 @@ class Host(AbortableEventEmitter):
event.connection_handle, event.connection_handle,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
event.peer_address, event.peer_address,
None,
event.role, event.role,
connection_parameters, connection_parameters,
) )
@@ -589,7 +586,6 @@ class Host(AbortableEventEmitter):
connection = Connection( connection = Connection(
self, self,
event.connection_handle, event.connection_handle,
BT_CENTRAL_ROLE,
event.bd_addr, event.bd_addr,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
) )
@@ -602,7 +598,6 @@ class Host(AbortableEventEmitter):
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
event.bd_addr, event.bd_addr,
None, None,
BT_CENTRAL_ROLE,
None, None,
) )
else: else:
@@ -622,8 +617,7 @@ class Host(AbortableEventEmitter):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug( logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] ' f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} as ' f'{connection.peer_address} '
f'{HCI_Constant.role_name(connection.role)}, '
f'reason={event.reason}' f'reason={event.reason}'
) )
del self.connections[event.connection_handle] del self.connections[event.connection_handle]
@@ -739,10 +733,6 @@ class Host(AbortableEventEmitter):
f'role change for {event.bd_addr}: ' f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}' f'{HCI_Constant.role_name(event.new_role)}'
) )
if connection := self.find_connection_by_bd_addr(
event.bd_addr, BT_BR_EDR_TRANSPORT
):
connection.role = event.new_role
self.emit('role_change', event.bd_addr, event.new_role) self.emit('role_change', event.bd_addr, event.new_role)
else: else:
logger.debug( logger.debug(

View File

@@ -273,7 +273,7 @@ class JsonKeyStore(KeyStore):
db = await self.load() db = await self.load()
namespace = db.setdefault(self.namespace, {}) namespace = db.setdefault(self.namespace, {})
namespace[name] = keys.to_dict() namespace.setdefault(name, {}).update(keys.to_dict())
await self.save(db) await self.save(db)

View File

@@ -439,7 +439,7 @@ class DLC(EventEmitter):
logger.debug( logger.debug(
f'<<< Credits [{self.dlci}]: ' f'<<< Credits [{self.dlci}]: '
f'received {credits}, total={self.tx_credits}' f'received {received_credits}, total={self.tx_credits}'
) )
data = data[1:] data = data[1:]

View File

@@ -553,7 +553,7 @@ class PairingConfig:
def __init__( def __init__(
self, self,
sc: bool = True, sc: bool = True,
mitm: bool = True, mitm: bool = False,
bonding: bool = True, bonding: bool = True,
delegate: Optional[PairingDelegate] = None, delegate: Optional[PairingDelegate] = None,
) -> None: ) -> None:

View File

@@ -1,4 +1,6 @@
{ {
"name": "Bumble Hands-Free", "name": "Bumble Hands-Free",
"class_of_device": 2360324 "class_of_device": 2360324,
"keystore": "JsonKeyStore",
"le_enabled": false
} }

View File

@@ -447,7 +447,7 @@ async def test_self_smp_wrong_pin():
async def compare_numbers(self, number, digits): async def compare_numbers(self, number, digits):
return False return False
wrong_pin_pairing_config = PairingConfig(delegate=WrongPinDelegate()) wrong_pin_pairing_config = PairingConfig(mitm=True, delegate=WrongPinDelegate())
paired = False paired = False
try: try:
await _test_self_smp_with_configs( await _test_self_smp_with_configs(