Compare commits

..

5 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod db5e52f1df add support for alternate settings 2022-09-20 22:25:40 -07:00
Gilles Boccon-Gibod d7da5a9379 add support for dynamic discovery of USB endpoints 2022-09-20 16:39:12 -07:00
Gilles Boccon-Gibod daa05b8996 Merge pull request #36 from google/gbg/pairing-with-no-distribution
gbg/pairing with no distribution
2022-09-02 10:17:31 -07:00
Gilles Boccon-Gibod 624e860762 support empty distributions in both directions 2022-08-30 18:50:48 -07:00
Gilles Boccon-Gibod 159cbf7774 support pairing with no key distribution 2022-08-30 18:28:24 -07:00
7 changed files with 135 additions and 155 deletions
+20 -109
View File
@@ -18,8 +18,7 @@
import json
import asyncio
import logging
import secrets
from contextlib import asynccontextmanager, AsyncExitStack
from contextlib import asynccontextmanager, AsyncExitStack
from .hci import *
from .host import Host
@@ -33,8 +32,6 @@ from . import smp
from . import sdp
from . import l2cap
from . import keys
from . import crypto
# -----------------------------------------------------------------------------
# Logging
@@ -54,7 +51,6 @@ DEVICE_DEFAULT_SCAN_RESPONSE_DATA = b''
DEVICE_DEFAULT_DATA_LENGTH = (27, 328, 27, 328)
DEVICE_DEFAULT_SCAN_INTERVAL = 60 # ms
DEVICE_DEFAULT_SCAN_WINDOW = 60 # ms
DEVICE_DEFAULT_LE_RPA_TIMEOUT = 15 * 60 # 15 minutes (in seconds)
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
@@ -173,6 +169,7 @@ class Peer:
async def __aexit__(self, exc_type, exc_value, traceback):
pass
def __str__(self):
return f'{self.connection.peer_address} as {self.connection.role_name}'
@@ -205,22 +202,11 @@ class Connection(CompositeEventEmitter):
def on_connection_encryption_key_refresh(self):
pass
def __init__(
self,
device,
handle,
transport,
local_address,
peer_address,
peer_resolvable_address,
role,
parameters
):
def __init__(self, device, handle, transport, peer_address, peer_resolvable_address, role, parameters):
super().__init__()
self.device = device
self.handle = handle
self.transport = transport
self.local_address = local_address
self.peer_address = peer_address
self.peer_resolvable_address = peer_resolvable_address
self.peer_name = None # Classic only
@@ -311,12 +297,7 @@ class Connection(CompositeEventEmitter):
raise
def __str__(self):
return (
f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'local_address={self.local_address}, '
f'peer_address={self.peer_address})'
)
return f'Connection(handle=0x{self.handle:04X}, role={self.role_name}, address={self.peer_address})'
# -----------------------------------------------------------------------------
@@ -330,10 +311,8 @@ class DeviceConfiguration:
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.le_enabled = True
# LE host enable 2nd parameter
self.le_simultaneous_enabled = True
self.le_privacy_enabled = False
self.le_rpa_timeout = DEVICE_DEFAULT_LE_RPA_TIMEOUT
self.classic_enabled = False
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
self.connectable = True
@@ -341,22 +320,19 @@ class DeviceConfiguration:
self.advertising_data = bytes(
AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))])
)
self.irk = bytes([0xFF] * 16) # This really must be changed for any level of security
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
def load_from_dict(self, config):
# Load simple properties
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get('advertising_interval', self.advertising_interval_min)
self.advertising_interval_max = self.advertising_interval_min
self.keystore = config.get('keystore')
self.le_enabled = config.get('le_enabled', self.le_enabled)
self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled)
self.le_privacy_enabled = config.get('le_privacy_enabled', self.le_privacy_enabled)
self.le_rpa_timeout = config.get('le_rpa_timeout', self.le_rpa_timeout)
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled)
self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled)
self.connectable = config.get('connectable', self.connectable)
@@ -376,10 +352,6 @@ class DeviceConfiguration:
advertising_data = config.get('advertising_data')
if advertising_data:
self.advertising_data = bytes.fromhex(advertising_data)
else:
self.advertising_data = bytes(
AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))])
)
def load_from_file(self, filename):
with open(filename, 'r') as file:
@@ -486,9 +458,9 @@ class Device(CompositeEventEmitter):
self.connecting = False
self.disconnecting = False
self.connections = {} # Connections, by connection handle
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.le_rpa_task = None
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
@@ -496,7 +468,6 @@ class Device(CompositeEventEmitter):
config = DeviceConfiguration()
self.name = config.name
self.random_address = config.address
self.identity_address = config.address
self.class_of_device = config.class_of_device
self.scan_response_data = config.scan_response_data
self.advertising_data = config.advertising_data
@@ -506,9 +477,6 @@ class Device(CompositeEventEmitter):
self.irk = config.irk
self.le_enabled = config.le_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.le_privacy_enabled = config.le_privacy_enabled
self.le_rpa_timeout = config.le_rpa_timeout
self.classic_enabled = config.classic_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_sc_enabled = config.classic_sc_enabled
self.discoverable = config.discoverable
@@ -522,12 +490,11 @@ class Device(CompositeEventEmitter):
if address:
if type(address) is str:
address = Address(address)
self.random_address = address
self.identity_address = address
self.random_address = address
# Setup SMP
# TODO: allow using a public address
self.smp_manager = smp.Manager(self, self.random_address, self.identity_address)
self.smp_manager = smp.Manager(self, self.random_address)
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_CID, self.on_smp_pdu)
self.l2cap_channel_manager.register_fixed_channel(
@@ -624,14 +591,6 @@ class Device(CompositeEventEmitter):
))
if self.le_enabled:
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = self.generate_le_rpa()
logger.info(f'Initial RPA: {self.random_address}')
if self.le_rpa_timeout > 0:
# Start a task to periodically generate a new RPA
self.le_rpa_task = asyncio.create_task(self.run_le_rpa_generation())
# Set the controller address
await self.send_command(HCI_LE_Set_Random_Address_Command(
random_address = self.random_address
@@ -678,48 +637,13 @@ class Device(CompositeEventEmitter):
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
# Let the SMP manager know about the address
# TODO: allow using a public address
self.smp_manager.address = self.random_address
# Done
self.powered_on = True
async def run_le_rpa_generation(self):
while self.le_rpa_timeout != 0:
await asyncio.sleep(self.le_rpa_timeout)
# Check if this is a good time to rotate the address
if self.advertising or self.scanning or self.connecting:
logger.debug('skipping RPA rotation')
continue
random_address = self.generate_le_rpa()
response = await self.send_command(HCI_LE_Set_Random_Address_Command(
random_address = self.random_address
))
if response.return_parameters == HCI_SUCCESS:
logger.info(f'New RPA: {random_address}')
self.random_address = random_address
else:
logger.warning(f'failed to set RPA: {response.return_parameters}')
def generate_le_rpa(self):
# See 1.3.2.2 Private device address generation
# Generate `prand`
while True:
# Generate a 22-bit random number for the random part of `prand`
prand_random = secrets.randbelow(0x400000)
# As least on bit shall be 0 and one bit shall be 1
if prand_random != 0 and prand_random != 0x3FFFFF:
break
prand = prand_random | 0x400000 # The two MSBs are |1|0|
# Generate `hash`
hash = crypto.ah(self.irk, struct.pack('<I', prand)[:3])
# Generate the address from `prand` and `hash`
return Address(hash + struct.pack('<I', prand)[:3], Address.RANDOM_IDENTITY_ADDRESS)
async def start_advertising(self, auto_restart=False):
self.auto_restart_advertising = auto_restart
@@ -751,24 +675,18 @@ class Device(CompositeEventEmitter):
))
# Enable advertising
response = await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
advertising_enable = 1
))
if response.return_parameters != HCI_SUCCESS:
logger.warning(f'HCI_LE_Set_Advertising_Enable_Command failed ({response.return_parameters})')
raise HCI_Error(response.return_parameters)
self.advertising = True
async def stop_advertising(self):
# Disable advertising
if self.advertising:
response = await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
await self.send_command(HCI_LE_Set_Advertising_Enable_Command(
advertising_enable = 0
))
if response.return_parameters != HCI_SUCCESS:
logger.warning(f'HCI_LE_Set_Advertising_Enable_Command failed ({response.return_parameters})')
raise HCI_Error(response.return_parameters)
self.advertising = False
@@ -803,23 +721,17 @@ class Device(CompositeEventEmitter):
))
# Enable scanning
response = await self.send_command(HCI_LE_Set_Scan_Enable_Command(
await self.send_command(HCI_LE_Set_Scan_Enable_Command(
le_scan_enable = 1,
filter_duplicates = 1 if filter_duplicates else 0
))
if response.return_parameters != HCI_SUCCESS:
raise HCI_Error(response.return_parameters)
self.scanning = True
async def stop_scanning(self):
response = await self.send_command(HCI_LE_Set_Scan_Enable_Command(
await self.send_command(HCI_LE_Set_Scan_Enable_Command(
le_scan_enable = 0,
filter_duplicates = 0
))
if response.return_parameters != HCI_SUCCESS:
raise HCI_Error(response.return_parameters)
self.scanning = False
@property
@@ -1330,7 +1242,6 @@ class Device(CompositeEventEmitter):
self,
connection_handle,
transport,
self.public_address if transport == BT_BR_EDR_TRANSPORT else self.random_address,
peer_address,
peer_resolvable_address,
role,
+3 -5
View File
@@ -1375,11 +1375,9 @@ class HCI_Error(ProtocolError):
class HCI_StatusError(ProtocolError):
def __init__(self, response):
super().__init__(
response.status,
error_namespace=HCI_Command.command_name(response.command_opcode),
error_name=HCI_Constant.status_name(response.status)
)
super().__init__(response.status,
error_namespace=HCI_Command.command_name(response.command_opcode),
error_name=HCI_Constant.status_name(response.status))
# -----------------------------------------------------------------------------
+5
View File
@@ -18,6 +18,8 @@
import logging
from colors import color
from bumble.smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
@@ -73,6 +75,9 @@ class PacketTracer:
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
+23 -16
View File
@@ -155,6 +155,7 @@ SMP_CT2_AUTHREQ = 0b00100000
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -638,13 +639,13 @@ class Session:
# Set up addresses
peer_address = connection.peer_resolvable_address or connection.peer_address
if self.is_initiator:
self.ia = bytes(connection.local_address)
self.iat = 1 if connection.local_address.is_random else 0
self.ia = bytes(manager.address)
self.iat = 1 if manager.address.is_random else 0
self.ra = bytes(peer_address)
self.rat = 1 if peer_address.is_random else 0
else:
self.ra = bytes(connection.local_address)
self.rat = 1 if connection.local_address.is_random else 0
self.ra = bytes(manager.address)
self.rat = 1 if manager.address.is_random else 0
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
@@ -906,8 +907,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.identity_address.address_type,
bd_addr = self.manager.identity_address
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
))
# Distribute CSRK
@@ -938,8 +939,8 @@ class Session:
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
)
self.send_command(SMP_Identity_Address_Information_Command(
addr_type = self.manager.identity_address.address_type,
bd_addr = self.manager.identity_address
addr_type = self.manager.address.address_type,
bd_addr = self.manager.address
))
# Distribute CSRK
@@ -980,12 +981,7 @@ class Session:
self.peer_expected_distributions.remove(command_class)
logger.debug(f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}')
if not self.peer_expected_distributions:
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
# Nothing left to expect, we're done
asyncio.create_task(self.on_pairing())
self.on_peer_key_distribution_complete()
else:
logger.warn(color(f'!!! unexpected key distribution command: {command_class.__name__}', 'red'))
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
@@ -1006,12 +1002,23 @@ class Session:
self.connection.remove_listener('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh)
self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self):
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
asyncio.create_task(self.on_pairing())
def on_connection_encryption_change(self):
if self.connection.is_encrypted:
if self.is_responder:
# The responder distributes its keys first, the initiator later
self.distribute_keys()
# If we're not expecting key distributions from the peer, we're done
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self):
# Do as if the connection had just been encrypted
self.on_connection_encryption_change()
@@ -1479,10 +1486,10 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol
'''
def __init__(self, device, address, identity_address):
def __init__(self, device, address):
super().__init__()
self.device = device
self.identity_address = identity_address
self.address = address
self.sessions = {}
self._ecc_key = None
self.pairing_config_factory = lambda connection: PairingConfig()
+82 -22
View File
@@ -56,18 +56,19 @@ async def open_usb_transport(spec):
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_ACL_OUT = 0x02
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device):
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
@@ -116,7 +117,7 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
USB_ENDPOINT_ACL_OUT,
self.acl_out,
packet[1:],
callback=self.on_packet_sent
)
@@ -152,10 +153,12 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device):
def __init__(self, context, device, acl_in, events_in):
super().__init__()
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
@@ -172,7 +175,7 @@ async def open_usb_transport(spec):
# Set up transfer objects for input
self.events_in_transfer = device.getTransfer()
self.events_in_transfer.setInterrupt(
USB_ENDPOINT_EVENTS_IN,
self.events_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET
@@ -181,7 +184,7 @@ async def open_usb_transport(spec):
self.acl_in_transfer = device.getTransfer()
self.acl_in_transfer.setBulk(
USB_ENDPOINT_ACL_IN,
self.acl_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET
@@ -248,7 +251,7 @@ async def open_usb_transport(spec):
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, source, sink):
def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
@@ -257,6 +260,10 @@ async def open_usb_transport(spec):
# Get exclusive access
device.claimInterface(interface)
# Set the alternate setting if not the default
if setting != 0:
device.setInterfaceAltSetting(interface, setting)
# The source and sink can now start
source.start()
sink.start()
@@ -313,10 +320,63 @@ async def open_usb_transport(spec):
raise ValueError('device not found')
logger.debug(f'USB Device: {found}')
device = found.open()
# Use the first interface
interface = 0
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
setting.getClass() != USB_DEVICE_CLASS_WIRELESS_CONTROLLER or
setting.getSubClass() != USB_DEVICE_SUBCLASS_RF_CONTROLLER or
setting.getProtocol() != USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
):
continue
events_in = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None:
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in
)
else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}')
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
f'interface={interface}, '
f'setting={setting}, '
f'acl_in=0x{acl_in:02X}, '
f'acl_out=0x{acl_out:02X}, '
f'events_in=0x{events_in:02X}, '
)
device = found.open()
# Detach the kernel driver if supported and needed
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
@@ -329,21 +389,21 @@ async def open_usb_transport(spec):
# Set the configuration if needed
try:
configuration = device.getConfiguration()
logger.debug(f'current configuration = {configuration}')
current_configuration = device.getConfiguration()
logger.debug(f'current configuration = {current_configuration}')
except usb1.USBError:
configuration = 0
current_configuration = 0
if configuration != 1:
if current_configuration != configuration:
try:
logger.debug('setting configuration 1')
device.setConfiguration(1)
logger.debug(f'setting configuration {configuration}')
device.setConfiguration(configuration)
except usb1.USBError:
logger.warning('failed to set configuration 1')
logger.warning('failed to set configuration')
source = UsbPacketSource(context, device)
sink = UsbPacketSink(device)
return UsbTransport(context, device, interface, source, sink)
source = UsbPacketSource(context, device, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
context.close()
+1 -1
View File
@@ -20,7 +20,7 @@ import sys
import os
import logging
from colors import color
from bumble.device import Device
from bumble.device import Device, Peer
from bumble.transport import open_transport
from bumble.profiles.battery_service import BatteryServiceProxy
+1 -2
View File
@@ -246,8 +246,7 @@ IO_CAP = [
SC = [False, True]
MITM = [False, True]
# Key distribution is a 4-bit bitmask
# IdKey is necessary for current SMP structure
KEY_DIST = [i for i in range(16) if (i & SMP_ID_KEY_DISTRIBUTION_FLAG)]
KEY_DIST = range(16)
@pytest.mark.asyncio
@pytest.mark.parametrize('io_cap, sc, mitm, key_dist',