Compare commits

..

2 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 368e7eff05 update tests to look at side effects instead of internals 2022-08-12 12:32:36 -07:00
Gilles Boccon-Gibod 55b813bbf5 don't use a lambda as a subscriber 2022-08-12 12:06:08 -07:00
8 changed files with 48 additions and 91 deletions
+2 -2
View File
@@ -29,11 +29,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]" python -m pip install ".[test,development,documentation]"
- name: Test with pytest - name: Test with pytest
run: | run: |
pytest pytest
- name: Build - name: Build
run: | run: |
inv build inv build
inv build.mkdocs inv mkdocs
+5 -13
View File
@@ -315,8 +315,6 @@ class DeviceConfiguration:
self.le_simultaneous_enabled = True self.le_simultaneous_enabled = True
self.classic_sc_enabled = True self.classic_sc_enabled = True
self.classic_ssp_enabled = True self.classic_ssp_enabled = True
self.connectable = True
self.discoverable = True
self.advertising_data = bytes( self.advertising_data = bytes(
AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]) AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))])
) )
@@ -335,8 +333,6 @@ class DeviceConfiguration:
self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled) self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled)
self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled) self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled)
self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled) self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled)
self.connectable = config.get('connectable', self.connectable)
self.discoverable = config.get('discoverable', self.discoverable)
# Load or synthesize an IRK # Load or synthesize an IRK
irk = config.get('irk') irk = config.get('irk')
@@ -450,8 +446,7 @@ class Device(CompositeEventEmitter):
self.command_timeout = 10 # seconds self.command_timeout = 10 # seconds
self.gatt_server = gatt_server.Server(self) self.gatt_server = gatt_server.Server(self)
self.sdp_server = sdp.Server(self) self.sdp_server = sdp.Server(self)
self.l2cap_channel_manager = l2cap.ChannelManager( self.l2cap_channel_manager = l2cap.ChannelManager()
[l2cap.L2CAP_Information_Request.EXTENDED_FEATURE_FIXED_CHANNELS])
self.advertisement_data = {} self.advertisement_data = {}
self.scanning = False self.scanning = False
self.discovering = False self.discovering = False
@@ -459,6 +454,8 @@ class Device(CompositeEventEmitter):
self.disconnecting = False self.disconnecting = False
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.classic_enabled = False self.classic_enabled = False
self.discoverable = False
self.connectable = False
self.inquiry_response = None self.inquiry_response = None
self.address_resolver = None self.address_resolver = None
@@ -479,8 +476,6 @@ class Device(CompositeEventEmitter):
self.le_simultaneous_enabled = config.le_simultaneous_enabled self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_sc_enabled = config.classic_sc_enabled self.classic_sc_enabled = config.classic_sc_enabled
self.discoverable = config.discoverable
self.connectable = config.connectable
# If a name is passed, override the name from the config # If a name is passed, override the name from the config
if name: if name:
@@ -495,8 +490,6 @@ class Device(CompositeEventEmitter):
# Setup SMP # Setup SMP
# TODO: allow using a public address # TODO: allow using a public address
self.smp_manager = smp.Manager(self, self.random_address) self.smp_manager = smp.Manager(self, self.random_address)
self.l2cap_channel_manager.register_fixed_channel(
smp.SMP_CID, self.on_smp_pdu)
# Register the SDP server with the L2CAP Channel Manager # Register the SDP server with the L2CAP Channel Manager
self.sdp_server.register(self.l2cap_channel_manager) self.sdp_server.register(self.l2cap_channel_manager)
@@ -504,7 +497,6 @@ class Device(CompositeEventEmitter):
# Add a GAP Service if requested # Add a GAP Service if requested
if generic_access_service: if generic_access_service:
self.gatt_server.add_service(GenericAccessService(self.name)) self.gatt_server.add_service(GenericAccessService(self.name))
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu)
# Forward some events # Forward some events
setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription') setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription')
@@ -631,8 +623,6 @@ class Device(CompositeEventEmitter):
HCI_Write_Secure_Connections_Host_Support_Command( HCI_Write_Secure_Connections_Host_Support_Command(
secure_connections_host_support=int(self.classic_sc_enabled)) secure_connections_host_support=int(self.classic_sc_enabled))
) )
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
# Let the SMP manager know about the address # Let the SMP manager know about the address
# TODO: allow using a public address # TODO: allow using a public address
@@ -1507,6 +1497,7 @@ class Device(CompositeEventEmitter):
def on_pairing_failure(self, connection, reason): def on_pairing_failure(self, connection, reason):
connection.emit('pairing_failure', reason) connection.emit('pairing_failure', reason)
@host_event_handler
@with_connection_from_handle @with_connection_from_handle
def on_gatt_pdu(self, connection, pdu): def on_gatt_pdu(self, connection, pdu):
# Parse the L2CAP payload into an ATT PDU object # Parse the L2CAP payload into an ATT PDU object
@@ -1525,6 +1516,7 @@ class Device(CompositeEventEmitter):
return return
connection.gatt_server.on_gatt_pdu(connection, att_pdu) connection.gatt_server.on_gatt_pdu(connection, att_pdu)
@host_event_handler
@with_connection_from_handle @with_connection_from_handle
def on_smp_pdu(self, connection, pdu): def on_smp_pdu(self, connection, pdu):
self.smp_manager.on_smp_pdu(connection, pdu) self.smp_manager.on_smp_pdu(connection, pdu)
+13 -1
View File
@@ -56,7 +56,13 @@ class Connection:
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
if l2cap_pdu.cid == ATT_CID:
self.host.on_gatt_pdu(self, l2cap_pdu.payload)
elif l2cap_pdu.cid == SMP_CID:
self.host.on_smp_pdu(self, l2cap_pdu.payload)
else:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -293,6 +299,12 @@ class Host(EventEmitter):
if connection := self.connections.get(packet.connection_handle): if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet) connection.on_hci_acl_data_packet(packet)
def on_gatt_pdu(self, connection, pdu):
self.emit('gatt_pdu', connection.handle, pdu)
def on_smp_pdu(self, connection, pdu):
self.emit('smp_pdu', connection.handle, pdu)
def on_l2cap_pdu(self, connection, cid, pdu): def on_l2cap_pdu(self, connection, cid, pdu):
self.emit('l2cap_pdu', connection.handle, cid, pdu) self.emit('l2cap_pdu', connection.handle, cid, pdu)
+9 -35
View File
@@ -20,11 +20,11 @@ import logging
import struct import struct
from colors import color from colors import color
from pyee import EventEmitter
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .hci import (HCI_LE_Connection_Update_Command, HCI_Object, key_with_value, from .hci import (HCI_LE_Connection_Update_Command, HCI_Object, key_with_value,
name_or_number) name_or_number)
from .utils import EventEmitter
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -414,18 +414,6 @@ class L2CAP_Information_Request(L2CAP_Control_Frame):
EXTENDED_FEATURES_SUPPORTED = 0x0002 EXTENDED_FEATURES_SUPPORTED = 0x0002
FIXED_CHANNELS_SUPPORTED = 0x0003 FIXED_CHANNELS_SUPPORTED = 0x0003
EXTENDED_FEATURE_FLOW_MODE_CONTROL = 0x0001
EXTENDED_FEATURE_RETRANSMISSION_MODE = 0x0002
EXTENDED_FEATURE_BIDIRECTIONAL_QOS = 0x0004
EXTENDED_FEATURE_ENHANCED_RETRANSMISSION_MODE = 0x0008
EXTENDED_FEATURE_STREAMING_MODE = 0x0010
EXTENDED_FEATURE_FCS_OPTION = 0x0020
EXTENDED_FEATURE_EXTENDED_FLOW_SPEC = 0x0040
EXTENDED_FEATURE_FIXED_CHANNELS = 0x0080
EXTENDED_FEATURE_EXTENDED_WINDOW_SIZE = 0x0100
EXTENDED_FEATURE_UNICAST_CONNECTIONLESS_DATA = 0x0200
EXTENDED_FEATURE_ENHANCED_CREDIT_BASE_FLOW_CONTROL = 0x0400
INFO_TYPE_NAMES = { INFO_TYPE_NAMES = {
CONNECTIONLESS_MTU: 'CONNECTIONLESS_MTU', CONNECTIONLESS_MTU: 'CONNECTIONLESS_MTU',
EXTENDED_FEATURES_SUPPORTED: 'EXTENDED_FEATURES_SUPPORTED', EXTENDED_FEATURES_SUPPORTED: 'EXTENDED_FEATURES_SUPPORTED',
@@ -829,16 +817,11 @@ class Channel(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ChannelManager: class ChannelManager:
def __init__(self, extended_features=None, connectionless_mtu=1024): def __init__(self):
self.host = None self.host = None
self.channels = {} # Channels, mapped by connection and cid self.channels = {} # Channels, mapped by connection and cid
# Fixed channel handlers, mapped by cid self.identifiers = {} # Incrementing identifier values by connection
self.fixed_channels = { self.servers = {} # Servers accepting connections, by PSM
L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None}
self.identifiers = {} # Incrementing identifier values by connection
self.servers = {} # Servers accepting connections, by PSM
self.extended_features = [] if extended_features is None else extended_features
self.connectionless_mtu = connectionless_mtu
def find_channel(self, connection_handle, cid): def find_channel(self, connection_handle, cid):
if connection_channels := self.channels.get(connection_handle): if connection_channels := self.channels.get(connection_handle):
@@ -858,13 +841,6 @@ class ChannelManager:
self.identifiers[connection.handle] = identifier self.identifiers[connection.handle] = identifier
return identifier return identifier
def register_fixed_channel(self, cid, handler):
self.fixed_channels[cid] = handler
def deregister_fixed_channel(self, cid):
if cid in self.fixed_channels:
del self.fixed_channels[cid]
def register_server(self, psm, server): def register_server(self, psm, server):
self.servers[psm] = server self.servers[psm] = server
@@ -879,8 +855,6 @@ class ChannelManager:
control_frame = L2CAP_Control_Frame.from_bytes(pdu) control_frame = L2CAP_Control_Frame.from_bytes(pdu)
self.on_control_frame(connection, cid, control_frame) self.on_control_frame(connection, cid, control_frame)
elif cid in self.fixed_channels:
self.fixed_channels[cid](connection.handle, pdu)
else: else:
if (channel := self.find_channel(connection.handle, cid)) is None: if (channel := self.find_channel(connection.handle, cid)) is None:
logger.warn(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red')) logger.warn(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red'))
@@ -1025,13 +999,13 @@ class ChannelManager:
def on_l2cap_information_request(self, connection, cid, request): def on_l2cap_information_request(self, connection, cid, request):
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU: if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
result = L2CAP_Information_Response.SUCCESS result = L2CAP_Information_Response.SUCCESS
data = self.connectionless_mtu.to_bytes(2, 'little') data = struct.pack('<H', 1024) # TODO: don't use a fixed value
elif request.info_type == L2CAP_Information_Request.EXTENDED_FEATURES_SUPPORTED: elif request.info_type == L2CAP_Information_Request.EXTENDED_FEATURES_SUPPORTED:
result = L2CAP_Information_Response.SUCCESS result = L2CAP_Information_Response.SUCCESS
data = sum(self.extended_features).to_bytes(4, 'little') data = bytes.fromhex('00000000') # TODO: don't use a fixed value
elif request.info_type == L2CAP_Information_Request.FIXED_CHANNELS_SUPPORTED: elif request.info_type == L2CAP_Information_Request.FIXED_CHANNELS_SUPPORTED:
result = L2CAP_Information_Response.SUCCESS result = L2CAP_Information_Response.SUCCESS
data = sum(1 << cid for cid in self.fixed_channels).to_bytes(8, 'little') data = bytes.fromhex('FFFFFFFFFFFFFFFF') # TODO: don't use a fixed value
else: else:
result = L2CAP_Information_Request.NO_SUPPORTED result = L2CAP_Information_Request.NO_SUPPORTED
+1 -2
View File
@@ -17,10 +17,9 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import asyncio import asyncio
from colors import color from colors import color
from pyee import EventEmitter
from .utils import EventEmitter
from .core import InvalidStateError, ProtocolError, ConnectionError from .core import InvalidStateError, ProtocolError, ConnectionError
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+2 -4
View File
@@ -7,12 +7,10 @@ The Console app is an interactive text user interface that offers a number of fu
* scanning * scanning
* advertising * advertising
* connecting to and disconnecting from devices * connecting to devices
* changing connection parameters * changing connection parameters
* enabling encryption
* discovering GATT services and characteristics * discovering GATT services and characteristics
* reading and writing GATT characteristics * read & write GATT characteristics
* subscribing to and unsubscribing from GATT characteristics
The console user interface has 3 main panes: The console user interface has 3 main panes:
+1 -2
View File
@@ -57,13 +57,12 @@ console_scripts =
bumble-link-relay = bumble.apps.link_relay.link_relay:main bumble-link-relay = bumble.apps.link_relay.link_relay:main
[options.extras_require] [options.extras_require]
build =
build >= 0.7
test = test =
pytest >= 6.2 pytest >= 6.2
pytest-asyncio >= 0.17 pytest-asyncio >= 0.17
development = development =
invoke >= 1.4 invoke >= 1.4
build >= 0.7
nox >= 2022 nox >= 2022
documentation = documentation =
mkdocs >= 1.2.3 mkdocs >= 1.2.3
+15 -32
View File
@@ -23,52 +23,35 @@ ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
ns = Collection() ns = Collection()
# Building
build_tasks = Collection() build_tasks = Collection()
ns.add_collection(build_tasks, name="build") ns.add_collection(build_tasks, name='build')
@task @task
def build(ctx, install=False): def build(ctx):
if install: ctx.run('python -m build')
ctx.run('python -m pip install .[build]')
ctx.run("python -m build") build_tasks.add_task(build, default=True, name='build')
build_tasks.add_task(build, default=True)
@task
def release_build(ctx):
build(ctx, install=True)
build_tasks.add_task(release_build, name="release")
@task
def mkdocs(ctx):
ctx.run("mkdocs build -f docs/mkdocs/mkdocs.yml")
build_tasks.add_task(mkdocs, name="mkdocs")
# Testing
test_tasks = Collection() test_tasks = Collection()
ns.add_collection(test_tasks, name="test") ns.add_collection(test_tasks, name='test')
@task @task
def test(ctx, filter=None, junit=False, install=False): def test(ctx, filter=None, junit=False):
# Install the package before running the tests
if install:
ctx.run("python -m pip install .[test]")
args = "" args = ""
if junit: if junit:
args += "--junit-xml test-results.xml" args += "--junit-xml test-results.xml"
if filter is not None: if filter is not None:
args += " -k '{}'".format(filter) args += " -k '{}'".format(filter)
ctx.run("python -m pytest {} {}".format(os.path.join(ROOT_DIR, "tests"), args)) ctx.run('python -m pytest {} {}'
.format(os.path.join(ROOT_DIR, "tests"), args))
test_tasks.add_task(test, name='test', default=True)
test_tasks.add_task(test, default=True)
@task @task
def release_test(ctx): def mkdocs(ctx):
test(ctx, install=True) ctx.run('mkdocs build -f docs/mkdocs/mkdocs.yml')
test_tasks.add_task(release_test, name="release")
ns.add_task(mkdocs)