From 55e2f23e29b5e665c8c64f8888f51300413acd21 Mon Sep 17 00:00:00 2001 From: Alan Rosenthal Date: Fri, 9 Dec 2022 12:23:45 -0500 Subject: [PATCH 1/2] Add bumble's version to `show device` --- .gitignore | 2 ++ apps/console.py | 3 +++ bumble/__init__.py | 4 ++++ pyproject.toml | 1 + 4 files changed, 10 insertions(+) diff --git a/.gitignore b/.gitignore index be8e77a8..48d845c6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ dist/ docs/mkdocs/site test-results.xml __pycache__ +# generated by setuptools_scm +bumble/_version.py diff --git a/apps/console.py b/apps/console.py index e0f93b09..e50ea3df 100644 --- a/apps/console.py +++ b/apps/console.py @@ -29,6 +29,7 @@ from collections import OrderedDict import click import colors +from bumble import __version__ from bumble.core import UUID, AdvertisingData, TimeoutError, BT_LE_TRANSPORT from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer from bumble.utils import AsyncRunner @@ -380,6 +381,8 @@ class ConsoleApp: def show_device(self, device): lines = [] + lines.append(('ansicyan', 'Bumble Version: ')) + lines.append(('', f'{__version__}\n')) lines.append(('ansicyan', 'Name: ')) lines.append(('', f'{device.name}\n')) lines.append(('ansicyan', 'Public Address: ')) diff --git a/bumble/__init__.py b/bumble/__init__.py index e69de29b..8a067ca4 100644 --- a/bumble/__init__.py +++ b/bumble/__init__.py @@ -0,0 +1,4 @@ +try: + from ._version import version as __version__ +except ImportError: + __version__ = "unknown version" diff --git a/pyproject.toml b/pyproject.toml index 6eca00c7..dbcc0008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,3 +3,4 @@ requires = ["setuptools>=52", "wheel", "setuptools_scm>=6.2"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] +write_to = "bumble/_version.py" From 135df0dcc01ab765f432e19b1a5202d29bd55545 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Sat, 10 Dec 2022 08:53:51 -0800 Subject: [PATCH 2/2] format with Black --- apps/console.py | 313 +- apps/controller_info.py | 42 +- apps/controllers.py | 13 +- apps/gatt_dump.py | 17 +- apps/gg_bridge.py | 137 +- apps/hci_bridge.py | 37 +- apps/l2cap_bridge.py | 160 +- apps/link_relay/link_relay.py | 34 +- apps/pair.py | 129 +- apps/scan.py | 71 +- apps/show.py | 34 +- apps/unbond.py | 2 +- apps/usb_probe.py | 92 +- bumble/a2dp.py | 409 +-- bumble/att.py | 304 +- bumble/avdtp.py | 496 ++-- bumble/bridge.py | 14 +- bumble/company_ids.py | 4 +- bumble/controller.py | 409 +-- bumble/core.py | 194 +- bumble/crypto.py | 114 +- bumble/device.py | 1325 +++++---- bumble/gap.py | 16 +- bumble/gatt.py | 115 +- bumble/gatt_client.py | 277 +- bumble/gatt_server.py | 346 ++- bumble/hci.py | 2527 ++++++++++------- bumble/helpers.py | 77 +- bumble/hfp.py | 22 +- bumble/host.py | 346 ++- bumble/keys.py | 46 +- bumble/l2cap.py | 1021 ++++--- bumble/link.py | 119 +- bumble/profiles/asha_service.py | 72 +- bumble/profiles/battery_service.py | 13 +- bumble/profiles/device_information_service.py | 59 +- bumble/profiles/heart_rate_service.py | 93 +- bumble/rfcomm.py | 414 +-- bumble/sdp.py | 376 ++- bumble/smp.py | 722 +++-- bumble/transport/__init__.py | 15 +- bumble/transport/android_emulator.py | 11 +- bumble/transport/common.py | 52 +- .../emulated_bluetooth_packets_pb2.py | 34 +- bumble/transport/emulated_bluetooth_pb2.py | 32 +- .../transport/emulated_bluetooth_pb2_grpc.py | 167 +- .../transport/emulated_bluetooth_vhci_pb2.py | 14 +- .../emulated_bluetooth_vhci_pb2_grpc.py | 59 +- bumble/transport/file.py | 7 +- bumble/transport/hci_socket.py | 33 +- bumble/transport/pty.py | 6 +- bumble/transport/pyusb.py | 72 +- bumble/transport/serial.py | 3 +- bumble/transport/tcp_server.py | 4 +- bumble/transport/udp.py | 7 +- bumble/transport/usb.py | 151 +- bumble/transport/vhci.py | 3 +- bumble/transport/ws_client.py | 2 +- bumble/transport/ws_server.py | 14 +- bumble/utils.py | 36 +- examples/async_runner.py | 3 +- examples/battery_client.py | 6 +- examples/battery_server.py | 24 +- examples/device_information_client.py | 57 +- examples/device_information_server.py | 28 +- examples/heart_rate_client.py | 6 +- examples/heart_rate_server.py | 58 +- examples/keyboard.py | 329 ++- examples/run_a2dp_info.py | 45 +- examples/run_a2dp_sink.py | 57 +- examples/run_a2dp_source.py | 64 +- examples/run_advertiser.py | 7 +- examples/run_asha_sink.py | 135 +- examples/run_classic_connect.py | 43 +- examples/run_classic_discoverable.py | 65 +- examples/run_classic_discovery.py | 21 +- examples/run_connect_and_encrypt.py | 11 +- examples/run_controller.py | 31 +- examples/run_controller_with_scanner.py | 17 +- examples/run_gatt_client.py | 7 +- examples/run_gatt_client_and_server.py | 27 +- examples/run_gatt_server.py | 40 +- examples/run_hfp_gateway.py | 83 +- examples/run_hfp_handsfree.py | 60 +- examples/run_notifier.py | 20 +- examples/run_rfcomm_client.py | 70 +- examples/run_rfcomm_server.py | 49 +- examples/run_scanner.py | 13 +- setup.py | 1 + tasks.py | 9 + tests/a2dp_test.py | 146 +- tests/avdtp_test.py | 24 +- tests/core_test.py | 33 +- tests/device_test.py | 191 +- tests/gatt_test.py | 226 +- tests/hci_test.py | 218 +- tests/import_test.py | 15 +- tests/l2cap_test.py | 67 +- tests/rfcomm_test.py | 4 +- tests/sdp_test.py | 34 +- tests/self_test.py | 157 +- tests/smp_test.py | 162 +- tests/transport_test.py | 12 +- web/scanner.py | 4 +- 104 files changed, 8646 insertions(+), 5766 deletions(-) diff --git a/apps/console.py b/apps/console.py index e0f93b09..927abda1 100644 --- a/apps/console.py +++ b/apps/console.py @@ -60,18 +60,18 @@ from prompt_toolkit.layout import ( FormattedTextControl, FloatContainer, ConditionalContainer, - Dimension + Dimension, ) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -BUMBLE_USER_DIR = os.path.expanduser('~/.bumble') -DEFAULT_RSSI_BAR_WIDTH = 20 +BUMBLE_USER_DIR = os.path.expanduser('~/.bumble') +DEFAULT_RSSI_BAR_WIDTH = 20 DEFAULT_CONNECTION_TIMEOUT = 30.0 -DISPLAY_MIN_RSSI = -100 -DISPLAY_MAX_RSSI = -30 -RSSI_MONITOR_INTERVAL = 5.0 # Seconds +DISPLAY_MIN_RSSI = -100 +DISPLAY_MAX_RSSI = -30 +RSSI_MONITOR_INTERVAL = 5.0 # Seconds # ----------------------------------------------------------------------------- @@ -84,12 +84,11 @@ App = None # Utils # ----------------------------------------------------------------------------- + def le_phy_name(phy_id): - return { - HCI_LE_1M_PHY: '1M', - HCI_LE_2M_PHY: '2M', - HCI_LE_CODED_PHY: 'CODED' - }.get(phy_id, HCI_Constant.le_phy_name(phy_id)) + return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get( + phy_id, HCI_Constant.le_phy_name(phy_id) + ) def rssi_bar(rssi): @@ -124,20 +123,22 @@ def parse_phys(phys): # ----------------------------------------------------------------------------- class ConsoleApp: def __init__(self): - self.known_addresses = set() + self.known_addresses = set() self.known_attributes = [] - self.device = None - self.connected_peer = None - self.top_tab = 'device' - self.monitor_rssi = False - self.connection_rssi = None + self.device = None + self.connected_peer = None + self.top_tab = 'device' + self.monitor_rssi = False + self.connection_rssi = None - style = Style.from_dict({ - 'output-field': 'bg:#000044 #ffffff', - 'input-field': 'bg:#000000 #ffffff', - 'line': '#004400', - 'error': 'fg:ansired' - }) + style = Style.from_dict( + { + 'output-field': 'bg:#000044 #ffffff', + 'input-field': 'bg:#000000 #ffffff', + 'line': '#004400', + 'error': 'fg:ansired', + } + ) class LiveCompleter(Completer): def __init__(self, words): @@ -149,52 +150,37 @@ class ConsoleApp: yield Completion(word, start_position=-len(prefix)) def make_completer(): - return NestedCompleter.from_nested_dict({ - 'scan': { - 'on': None, - 'off': None, - 'clear': None - }, - 'advertise': { - 'on': None, - 'off': None - }, - 'rssi': { - 'on': None, - 'off': None - }, - 'show': { - 'scan': None, - 'services': None, - 'attributes': None, - 'log': None, - 'device': None - }, - 'filter': { - 'address': None, - }, - 'connect': LiveCompleter(self.known_addresses), - 'update-parameters': None, - 'encrypt': None, - 'disconnect': None, - 'discover': { - 'services': None, - 'attributes': None - }, - 'request-mtu': None, - 'read': LiveCompleter(self.known_attributes), - 'write': LiveCompleter(self.known_attributes), - 'subscribe': LiveCompleter(self.known_attributes), - 'unsubscribe': LiveCompleter(self.known_attributes), - 'set-phy': { - '1m': None, - '2m': None, - 'coded': None - }, - 'set-default-phy': None, - 'quit': None, - 'exit': None - }) + return NestedCompleter.from_nested_dict( + { + 'scan': {'on': None, 'off': None, 'clear': None}, + 'advertise': {'on': None, 'off': None}, + 'rssi': {'on': None, 'off': None}, + 'show': { + 'scan': None, + 'services': None, + 'attributes': None, + 'log': None, + 'device': None, + }, + 'filter': { + 'address': None, + }, + 'connect': LiveCompleter(self.known_addresses), + 'update-parameters': None, + 'encrypt': None, + 'disconnect': None, + 'discover': {'services': None, 'attributes': None}, + 'request-mtu': None, + 'read': LiveCompleter(self.known_attributes), + 'write': LiveCompleter(self.known_attributes), + 'subscribe': LiveCompleter(self.known_attributes), + 'unsubscribe': LiveCompleter(self.known_attributes), + 'set-phy': {'1m': None, '2m': None, 'coded': None}, + 'set-default-phy': None, + 'quit': None, + 'exit': None, + } + ) self.input_field = TextArea( height=1, @@ -202,49 +188,55 @@ class ConsoleApp: multiline=False, wrap_lines=False, completer=make_completer(), - history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')) + history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')), ) self.input_field.accept_handler = self.accept_input self.output_height = Dimension(min=7, max=7, weight=1) self.output_lines = [] - self.output = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1))) + self.output = FormattedTextControl( + get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1)) + ) self.output_max_lines = 20 self.scan_results_text = FormattedTextControl() self.services_text = FormattedTextControl() self.attributes_text = FormattedTextControl() self.device_text = FormattedTextControl() - self.log_text = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1))) + self.log_text = FormattedTextControl( + get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1)) + ) self.log_height = Dimension(min=7, weight=4) self.log_max_lines = 100 self.log_lines = [] - container = HSplit([ - ConditionalContainer( - Frame(Window(self.scan_results_text), title='Scan Results'), - filter=Condition(lambda: self.top_tab == 'scan') - ), - ConditionalContainer( - Frame(Window(self.services_text), title='Services'), - filter=Condition(lambda: self.top_tab == 'services') - ), - ConditionalContainer( - Frame(Window(self.attributes_text), title='Attributes'), - filter=Condition(lambda: self.top_tab == 'attributes') - ), - ConditionalContainer( - Frame(Window(self.log_text, height=self.log_height), title='Log'), - filter=Condition(lambda: self.top_tab == 'log') - ), - ConditionalContainer( - Frame(Window(self.device_text), title='Device'), - filter=Condition(lambda: self.top_tab == 'device') - ), - Frame(Window(self.output, height=self.output_height)), - FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'), - self.input_field - ]) + container = HSplit( + [ + ConditionalContainer( + Frame(Window(self.scan_results_text), title='Scan Results'), + filter=Condition(lambda: self.top_tab == 'scan'), + ), + ConditionalContainer( + Frame(Window(self.services_text), title='Services'), + filter=Condition(lambda: self.top_tab == 'services'), + ), + ConditionalContainer( + Frame(Window(self.attributes_text), title='Attributes'), + filter=Condition(lambda: self.top_tab == 'attributes'), + ), + ConditionalContainer( + Frame(Window(self.log_text, height=self.log_height), title='Log'), + filter=Condition(lambda: self.top_tab == 'log'), + ), + ConditionalContainer( + Frame(Window(self.device_text), title='Device'), + filter=Condition(lambda: self.top_tab == 'device'), + ), + Frame(Window(self.output, height=self.output_height)), + FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'), + self.input_field, + ] + ) container = FloatContainer( container, @@ -260,16 +252,14 @@ class ConsoleApp: layout = Layout(container, focused_element=self.input_field) kb = KeyBindings() + @kb.add("c-c") @kb.add("c-q") def _(event): event.app.exit() self.ui = Application( - layout=layout, - style=style, - key_bindings=kb, - full_screen=True + layout=layout, style=style, key_bindings=kb, full_screen=True ) async def run_async(self, device_config, transport): @@ -277,13 +267,19 @@ class ConsoleApp: async with await open_transport_or_link(transport) as (hci_source, hci_sink): if device_config: - self.device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) + self.device = Device.from_config_file_with_hci( + device_config, hci_source, hci_sink + ) else: - random_address = f"{random.randint(192,255):02X}" # address is static random + random_address = ( + f"{random.randint(192,255):02X}" # address is static random + ) for c in random.sample(range(255), 5): random_address += f":{c:02X}" self.append_to_log(f"Setting random address: {random_address}") - self.device = Device.with_hci('Bumble', random_address, hci_source, hci_sink) + self.device = Device.with_hci( + 'Bumble', random_address, hci_source, hci_sink + ) self.device.listener = DeviceListener(self) await self.device.power_on() self.show_device(self.device) @@ -321,7 +317,9 @@ class ConsoleApp: else: phy_state = '' connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}{phy_state}' - encryption_state = 'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED' + encryption_state = ( + 'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED' + ) att_mtu = f'ATT_MTU: {connection.att_mtu}' return [ @@ -333,10 +331,10 @@ class ConsoleApp: ('', ' '), ('ansicyan', f' {att_mtu} '), ('', ' '), - ('ansiyellow', f' {rssi} ') + ('ansiyellow', f' {rssi} '), ] - def show_error(self, title, details = None): + def show_error(self, title, details=None): appended = [('class:error', title)] if details: appended.append(('', f' {details}')) @@ -359,7 +357,9 @@ class ConsoleApp: for characteristic in service.characteristics: lines.append(('ansimagenta', ' ' + str(characteristic) + '\n')) - self.known_attributes.append(f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}') + self.known_attributes.append( + f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}' + ) self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}') self.known_attributes.append(f'#{characteristic.handle:X}') for descriptor in characteristic.descriptors: @@ -418,7 +418,7 @@ class ConsoleApp: def append_to_output(self, line, invalidate=True): if type(line) is str: line = [('', line)] - self.output_lines = self.output_lines[-self.output_max_lines:] + self.output_lines = self.output_lines[-self.output_max_lines :] self.output_lines.append(line) formatted_text = [] for line in self.output_lines: @@ -430,7 +430,7 @@ class ConsoleApp: def append_to_log(self, lines, invalidate=True): self.log_lines.extend(lines.split('\n')) - self.log_lines = self.log_lines[-self.log_max_lines:] + self.log_lines = self.log_lines[-self.log_max_lines :] self.log_text.text = ANSI('\n'.join(self.log_lines)) if invalidate: self.ui.invalidate() @@ -515,7 +515,10 @@ class ConsoleApp: elif params[0] == 'on': if len(params) == 2: if not params[1].startswith("filter="): - self.show_error('invalid syntax', 'expected address filter=key1:value1,key2:value,... available filters: address') + self.show_error( + 'invalid syntax', + 'expected address filter=key1:value1,key2:value,... available filters: address', + ) # regex: (word):(any char except ,) matches = re.findall(r"(\w+):([^,]+)", params[1]) for match in matches: @@ -557,8 +560,7 @@ class ConsoleApp: connection_parameters_preferences = None else: connection_parameters_preferences = { - phy: ConnectionParametersPreferences() - for phy in phys + phy: ConnectionParametersPreferences() for phy in phys } if self.device.is_scanning: @@ -570,7 +572,7 @@ class ConsoleApp: await self.device.connect( params[0], connection_parameters_preferences=connection_parameters_preferences, - timeout=DEFAULT_CONNECTION_TIMEOUT + timeout=DEFAULT_CONNECTION_TIMEOUT, ) self.top_tab = 'services' except TimeoutError: @@ -588,7 +590,10 @@ class ConsoleApp: async def do_update_parameters(self, params): if len(params) != 1 or len(params[0].split('/')) != 3: - self.show_error('invalid syntax', 'expected update-parameters -//') + self.show_error( + 'invalid syntax', + 'expected update-parameters -//', + ) return if not self.connected_peer: @@ -596,14 +601,16 @@ class ConsoleApp: return connection_intervals, max_latency, supervision_timeout = params[0].split('/') - connection_interval_min, connection_interval_max = [int(x) for x in connection_intervals.split('-')] + connection_interval_min, connection_interval_max = [ + int(x) for x in connection_intervals.split('-') + ] max_latency = int(max_latency) supervision_timeout = int(supervision_timeout) await self.connected_peer.connection.update_parameters( connection_interval_min, connection_interval_max, max_latency, - supervision_timeout + supervision_timeout, ) async def do_encrypt(self, params): @@ -639,7 +646,9 @@ class ConsoleApp: return phy = await self.connected_peer.connection.get_phy() - self.append_to_output(f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}') + self.append_to_output( + f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}' + ) async def do_request_mtu(self, params): if len(params) != 1: @@ -721,7 +730,9 @@ class ConsoleApp: return await characteristic.subscribe( - lambda value: self.append_to_output(f"{characteristic} VALUE: 0x{value.hex()}"), + lambda value: self.append_to_output( + f"{characteristic} VALUE: 0x{value.hex()}" + ), ) async def do_unsubscribe(self, params): @@ -742,7 +753,9 @@ class ConsoleApp: async def do_set_phy(self, params): if len(params) != 1: - self.show_error('invalid syntax', 'expected set-phy |/') + self.show_error( + 'invalid syntax', 'expected set-phy |/' + ) return if not self.connected_peer: @@ -756,13 +769,15 @@ class ConsoleApp: rx_phys = tx_phys await self.connected_peer.connection.set_phy( - tx_phys=parse_phys(tx_phys), - rx_phys=parse_phys(rx_phys) + tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys) ) async def do_set_default_phy(self, params): if len(params) != 1: - self.show_error('invalid syntax', 'expected set-default-phy |/') + self.show_error( + 'invalid syntax', + 'expected set-default-phy |/', + ) return if '/' in params[0]: @@ -772,8 +787,7 @@ class ConsoleApp: rx_phys = tx_phys await self.device.set_default_phy( - tx_phys=parse_phys(tx_phys), - rx_phys=parse_phys(rx_phys) + tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys) ) async def do_exit(self, params): @@ -789,6 +803,7 @@ class ConsoleApp: return self.device.listener.address_filter = params[1] + # ----------------------------------------------------------------------------- # Device and Connection Listener # ----------------------------------------------------------------------------- @@ -808,7 +823,9 @@ class DeviceListener(Device.Listener, Connection.Listener): self._address_filter = re.compile(r".*") else: self._address_filter = re.compile(filter_addr) - self.scan_results = OrderedDict(filter(lambda x: self.filter_address_match(x), self.scan_results)) + self.scan_results = OrderedDict( + filter(lambda x: self.filter_address_match(x), self.scan_results) + ) self.app.show_scan_results(self.scan_results) def filter_address_match(self, address): @@ -825,24 +842,36 @@ class DeviceListener(Device.Listener, Connection.Listener): connection.listener = self def on_disconnection(self, reason): - self.app.append_to_output(f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}') + self.app.append_to_output( + f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}' + ) self.app.connected_peer = None self.app.connection_rssi = None def on_connection_parameters_update(self): - self.app.append_to_output(f'connection parameters update: {self.app.connected_peer.connection.parameters}') + self.app.append_to_output( + f'connection parameters update: {self.app.connected_peer.connection.parameters}' + ) def on_connection_phy_update(self): - self.app.append_to_output(f'connection phy update: {self.app.connected_peer.connection.phy}') + self.app.append_to_output( + f'connection phy update: {self.app.connected_peer.connection.phy}' + ) def on_connection_att_mtu_update(self): - self.app.append_to_output(f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}') + self.app.append_to_output( + f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}' + ) def on_connection_encryption_change(self): - self.app.append_to_output(f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}') + self.app.append_to_output( + f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}' + ) def on_connection_data_length_change(self): - self.app.append_to_output(f'connection data length change: {self.app.connected_peer.connection.data_length}') + self.app.append_to_output( + f'connection data length change: {self.app.connected_peer.connection.data_length}' + ) def on_advertisement(self, advertisement): if not self.filter_address_match(str(advertisement.address)): @@ -851,12 +880,18 @@ class DeviceListener(Device.Listener, Connection.Listener): entry_key = f'{advertisement.address}/{advertisement.address.address_type}' entry = self.scan_results.get(entry_key) if entry: - entry.ad_data = advertisement.data - entry.rssi = advertisement.rssi + entry.ad_data = advertisement.data + entry.rssi = advertisement.rssi entry.connectable = advertisement.is_connectable else: self.app.add_known_address(str(advertisement.address)) - self.scan_results[entry_key] = ScanResult(advertisement.address, advertisement.address.address_type, advertisement.data, advertisement.rssi, advertisement.is_connectable) + self.scan_results[entry_key] = ScanResult( + advertisement.address, + advertisement.address.address_type, + advertisement.data, + advertisement.rssi, + advertisement.is_connectable, + ) self.app.show_scan_results(self.scan_results) @@ -866,11 +901,11 @@ class DeviceListener(Device.Listener, Connection.Listener): # ----------------------------------------------------------------------------- class ScanResult: def __init__(self, address, address_type, ad_data, rssi, connectable): - self.address = address + self.address = address self.address_type = address_type - self.ad_data = ad_data - self.rssi = rssi - self.connectable = connectable + self.ad_data = ad_data + self.rssi = rssi + self.connectable = connectable def to_display_string(self): address_type_string = ('P', 'R', 'PI', 'RI')[self.address_type] diff --git a/apps/controller_info.py b/apps/controller_info.py index b65caab0..22a1ce58 100644 --- a/apps/controller_info.py +++ b/apps/controller_info.py @@ -39,7 +39,7 @@ from bumble.hci import ( HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND, HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command, HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND, - HCI_LE_Read_Maximum_Advertising_Data_Length_Command + HCI_LE_Read_Maximum_Advertising_Data_Length_Command, ) from bumble.host import Host from bumble.transport import open_transport_or_link @@ -51,13 +51,18 @@ async def get_classic_info(host): response = await host.send_command(HCI_Read_BD_ADDR_Command()) if response.return_parameters.status == HCI_SUCCESS: print() - print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr) + print( + color('Classic Address:', 'yellow'), response.return_parameters.bd_addr + ) if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): response = await host.send_command(HCI_Read_Local_Name_Command()) if response.return_parameters.status == HCI_SUCCESS: print() - print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name)) + print( + color('Local Name:', 'yellow'), + map_null_terminated_utf8_string(response.return_parameters.local_name), + ) # ----------------------------------------------------------------------------- @@ -65,21 +70,25 @@ async def get_le_info(host): print() if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND): - response = await host.send_command(HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()) + response = await host.send_command( + HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command() + ) if response.return_parameters.status == HCI_SUCCESS: print( color('LE Number Of Supported Advertising Sets:', 'yellow'), response.return_parameters.num_supported_advertising_sets, - '\n' + '\n', ) if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND): - response = await host.send_command(HCI_LE_Read_Maximum_Advertising_Data_Length_Command()) + response = await host.send_command( + HCI_LE_Read_Maximum_Advertising_Data_Length_Command() + ) if response.return_parameters.status == HCI_SUCCESS: print( color('LE Maximum Advertising Data Length:', 'yellow'), response.return_parameters.max_advertising_data_length, - '\n' + '\n', ) if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND): @@ -93,7 +102,7 @@ async def get_le_info(host): f'rx:{response.return_parameters.supported_max_rx_octets}/' f'{response.return_parameters.supported_max_rx_time}' ), - '\n' + '\n', ) print(color('LE Features:', 'yellow')) @@ -112,10 +121,19 @@ async def async_main(transport): # Print version print(color('Version:', 'yellow')) - print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier)) - print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version)) + print( + color(' Manufacturer: ', 'green'), + name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier), + ) + print( + color(' HCI Version: ', 'green'), + name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version), + ) print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion) - print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version)) + print( + color(' LMP Version: ', 'green'), + name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version), + ) print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion) # Get the Classic info @@ -135,7 +153,7 @@ async def async_main(transport): @click.command() @click.argument('transport') def main(transport): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) asyncio.run(async_main(transport)) diff --git a/apps/controllers.py b/apps/controllers.py index 8e5f70ee..db1f2d29 100644 --- a/apps/controllers.py +++ b/apps/controllers.py @@ -28,7 +28,9 @@ from bumble.transport import open_transport_or_link # ----------------------------------------------------------------------------- async def async_main(): if len(sys.argv) != 3: - print('Usage: controllers.py [ ...]') + print( + 'Usage: controllers.py [ ...]' + ) print('example: python controllers.py pty:ble1 pty:ble2') return @@ -41,7 +43,12 @@ async def async_main(): for index, transport_name in enumerate(sys.argv[1:]): transport = await open_transport_or_link(transport_name) transports.append(transport) - controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link) + controller = Controller( + f'C{index}', + host_source=transport.source, + host_sink=transport.sink, + link=link, + ) controllers.append(controller) # Wait until the user interrupts @@ -54,7 +61,7 @@ async def async_main(): # ----------------------------------------------------------------------------- def main(): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(async_main()) diff --git a/apps/gatt_dump.py b/apps/gatt_dump.py index 02c3641c..e415408a 100644 --- a/apps/gatt_dump.py +++ b/apps/gatt_dump.py @@ -64,9 +64,13 @@ async def async_main(device_config, encrypt, transport, address_or_name): # Create a device if device_config: - device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) + device = Device.from_config_file_with_hci( + device_config, hci_source, hci_sink + ) else: - device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) + device = Device.with_hci( + 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink + ) await device.power_on() if address_or_name: @@ -81,7 +85,12 @@ async def async_main(device_config, encrypt, transport, address_or_name): else: # Wait for a connection done = asyncio.get_running_loop().create_future() - device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done))) + device.on( + 'connection', + lambda connection: asyncio.create_task( + dump_gatt_db(Peer(connection), done) + ), + ) await device.start_advertising(auto_restart=True) print(color('### Waiting for connection...', 'blue')) @@ -99,7 +108,7 @@ def main(device_config, encrypt, transport, address_or_name): Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, wait for an incoming connection. """ - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) diff --git a/apps/gg_bridge.py b/apps/gg_bridge.py index ac3df8d7..47e23f04 100644 --- a/apps/gg_bridge.py +++ b/apps/gg_bridge.py @@ -33,10 +33,12 @@ from bumble.hci import HCI_Constant # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8' -GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8' -GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8' -GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8' +GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8' +GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8' +GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8' +GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = ( + 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8' +) GG_PREFERRED_MTU = 256 @@ -44,8 +46,8 @@ GG_PREFERRED_MTU = 256 # ----------------------------------------------------------------------------- class GattlinkL2capEndpoint: def __init__(self): - self.l2cap_channel = None - self.l2cap_packet = b'' + self.l2cap_channel = None + self.l2cap_packet = b'' self.l2cap_packet_size = 0 # Called when an L2CAP SDU has been received @@ -71,12 +73,12 @@ class GattlinkL2capEndpoint: class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener): def __init__(self, device, peer_address): super().__init__() - self.device = device - self.peer_address = peer_address - self.peer = None - self.tx_socket = None - self.rx_characteristic = None - self.tx_characteristic = None + self.device = device + self.peer_address = peer_address + self.peer = None + self.tx_socket = None + self.rx_characteristic = None + self.tx_characteristic = None self.l2cap_psm_characteristic = None device.listener = self @@ -127,7 +129,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener): self.rx_characteristic = characteristic elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID: self.tx_characteristic = characteristic - elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID: + elif ( + characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID + ): self.l2cap_psm_characteristic = characteristic print('RX:', self.rx_characteristic) print('TX:', self.tx_characteristic) @@ -135,7 +139,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener): if self.l2cap_psm_characteristic: # Subscribe to and then read the PSM value - await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received) + await self.peer.subscribe( + self.l2cap_psm_characteristic, self.on_l2cap_psm_received + ) psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic) psm = struct.unpack(' [command-short-circuit-list]') - print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078') + print( + 'Usage: hci_bridge.py [command-short-circuit-list]' + ) + print( + 'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078' + ) return print('>>> connecting to HCI...') - async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink): + async with await transport.open_transport_or_link(sys.argv[1]) as ( + hci_host_source, + hci_host_sink, + ): print('>>> connected') print('>>> connecting to HCI...') - async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink): + async with await transport.open_transport_or_link(sys.argv[2]) as ( + hci_controller_source, + hci_controller_sink, + ): print('>>> connected') command_short_circuits = [] @@ -51,18 +61,23 @@ async def async_main(): for op_code_str in sys.argv[3].split(','): if ':' in op_code_str: ogf, ocf = op_code_str.split(':') - command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))) + command_short_circuits.append( + hci.hci_command_op_code(int(ogf, 16), int(ocf, 16)) + ) else: command_short_circuits.append(int(op_code_str, 16)) def host_to_controller_filter(hci_packet): - if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits: + if ( + hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET + and hci_packet.op_code in command_short_circuits + ): # Respond with a success response logger.debug('short-circuiting packet') response = hci.HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = hci_packet.op_code, - return_parameters = bytes([hci.HCI_SUCCESS]) + num_hci_command_packets=1, + command_opcode=hci_packet.op_code, + return_parameters=bytes([hci.HCI_SUCCESS]), ) # Return a packet with 'respond to sender' set to True return (response.to_bytes(), True) @@ -73,14 +88,14 @@ async def async_main(): hci_controller_source, hci_controller_sink, host_to_controller_filter, - None + None, ) await asyncio.get_running_loop().create_future() # ----------------------------------------------------------------------------- def main(): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(async_main()) diff --git a/apps/l2cap_bridge.py b/apps/l2cap_bridge.py index ba658c21..2d80afd3 100644 --- a/apps/l2cap_bridge.py +++ b/apps/l2cap_bridge.py @@ -38,36 +38,32 @@ class ServerBridge: and waits for a new L2CAP CoC channel to be connected. When the TCP connection is closed by the TCP server, XXXX """ - def __init__( - self, - psm, - max_credits, - mtu, - mps, - tcp_host, - tcp_port - ): - self.psm = psm + + def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port): + self.psm = psm self.max_credits = max_credits - self.mtu = mtu - self.mps = mps - self.tcp_host = tcp_host - self.tcp_port = tcp_port + self.mtu = mtu + self.mps = mps + self.tcp_host = tcp_host + self.tcp_port = tcp_port async def start(self, device): # Listen for incoming L2CAP CoC connections device.register_l2cap_channel_server( - psm = self.psm, - server = self.on_coc, - max_credits = self.max_credits, - mtu = self.mtu, - mps = self.mps + psm=self.psm, + server=self.on_coc, + max_credits=self.max_credits, + mtu=self.mtu, + mps=self.mps, ) print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow')) def on_ble_connection(connection): def on_ble_disconnection(reason): - print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason)) + print( + color('@@@ Bluetooth disconnection:', 'red'), + HCI_Constant.error_name(reason), + ) print(color('@@@ Bluetooth connection:', 'green'), connection) connection.on('disconnection', on_ble_disconnection) @@ -82,7 +78,7 @@ class ServerBridge: class Pipe: def __init__(self, bridge, l2cap_channel): - self.bridge = bridge + self.bridge = bridge self.tcp_transport = None self.l2cap_channel = l2cap_channel @@ -91,7 +87,12 @@ class ServerBridge: async def connect_to_tcp(self): # Connect to the TCP server - print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow')) + print( + color( + f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', + 'yellow', + ) + ) class TcpClientProtocol(asyncio.Protocol): def __init__(self, pipe): @@ -107,7 +108,10 @@ class ServerBridge: self.pipe.l2cap_channel.write(data) try: - self.tcp_transport, _ = await asyncio.get_running_loop().create_connection( + ( + self.tcp_transport, + _, + ) = await asyncio.get_running_loop().create_connection( lambda: TcpClientProtocol(self), host=self.bridge.tcp_host, port=self.bridge.tcp_port, @@ -149,23 +153,14 @@ class ClientBridge: READ_CHUNK_SIZE = 4096 - def __init__( - self, - psm, - max_credits, - mtu, - mps, - address, - tcp_host, - tcp_port - ): - self.psm = psm + def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port): + self.psm = psm self.max_credits = max_credits - self.mtu = mtu - self.mps = mps - self.address = address - self.tcp_host = tcp_host - self.tcp_port = tcp_port + self.mtu = mtu + self.mps = mps + self.address = address + self.tcp_host = tcp_host + self.tcp_port = tcp_port async def start(self, device): print(color(f'### Connecting to {self.address}...', 'yellow')) @@ -174,7 +169,10 @@ class ClientBridge: # Called when the BLE connection is disconnected def on_ble_disconnection(reason): - print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason)) + print( + color('@@@ Bluetooth disconnection:', 'red'), + HCI_Constant.error_name(reason), + ) connection.on('disconnection', on_ble_disconnection) @@ -196,10 +194,10 @@ class ClientBridge: print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow')) try: l2cap_channel = await connection.open_l2cap_channel( - psm = self.psm, - max_credits = self.max_credits, - mtu = self.mtu, - mps = self.mps + psm=self.psm, + max_credits=self.max_credits, + mtu=self.mtu, + mps=self.mps, ) print(color('*** L2CAP channel:', 'cyan'), l2cap_channel) except Exception as error: @@ -215,7 +213,7 @@ class ClientBridge: l2cap_channel.pause_reading, l2cap_channel.resume_reading, writer.write, - writer.drain + writer.drain, ) l2cap_to_tcp_pipe.start() @@ -242,9 +240,13 @@ class ClientBridge: await asyncio.start_server( on_tcp_connection, host=self.tcp_host if self.tcp_host != '_' else None, - port=self.tcp_port + port=self.tcp_port, + ) + print( + color( + f'### Listening for TCP connections on port {self.tcp_port}', 'magenta' + ) ) - print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta')) # ----------------------------------------------------------------------------- @@ -266,20 +268,43 @@ async def run(device_config, hci_transport, bridge): # ----------------------------------------------------------------------------- @click.group() @click.pass_context -@click.option('--device-config', help='Device configuration file', required=True) -@click.option('--hci-transport', help='HCI transport', required=True) -@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234) -@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128) -@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022) -@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024) -def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps): +@click.option('--device-config', help='Device configuration file', required=True) +@click.option('--hci-transport', help='HCI transport', required=True) +@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234) +@click.option( + '--l2cap-coc-max-credits', + help='Maximum L2CAP CoC Credits', + type=click.IntRange(1, 65535), + default=128, +) +@click.option( + '--l2cap-coc-mtu', + help='L2CAP CoC MTU', + type=click.IntRange(23, 65535), + default=1022, +) +@click.option( + '--l2cap-coc-mps', + help='L2CAP CoC MPS', + type=click.IntRange(23, 65533), + default=1024, +) +def cli( + context, + device_config, + hci_transport, + psm, + l2cap_coc_max_credits, + l2cap_coc_mtu, + l2cap_coc_mps, +): context.ensure_object(dict) context.obj['device_config'] = device_config context.obj['hci_transport'] = hci_transport - context.obj['psm'] = psm - context.obj['max_credits'] = l2cap_coc_max_credits - context.obj['mtu'] = l2cap_coc_mtu - context.obj['mps'] = l2cap_coc_mps + context.obj['psm'] = psm + context.obj['max_credits'] = l2cap_coc_max_credits + context.obj['mtu'] = l2cap_coc_mtu + context.obj['mps'] = l2cap_coc_mps # ----------------------------------------------------------------------------- @@ -294,12 +319,9 @@ def server(context, tcp_host, tcp_port): context.obj['mtu'], context.obj['mps'], tcp_host, - tcp_port) - asyncio.run(run( - context.obj['device_config'], - context.obj['hci_transport'], - bridge - )) + tcp_port, + ) + asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge)) # ----------------------------------------------------------------------------- @@ -316,16 +338,12 @@ def client(context, bluetooth_address, tcp_host, tcp_port): context.obj['mps'], bluetooth_address, tcp_host, - tcp_port + tcp_port, ) - asyncio.run(run( - context.obj['device_config'], - context.obj['hci_transport'], - bridge - )) + asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge)) # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) if __name__ == '__main__': cli(obj={}) diff --git a/apps/link_relay/link_relay.py b/apps/link_relay/link_relay.py index c979ea67..89400bc9 100644 --- a/apps/link_relay/link_relay.py +++ b/apps/link_relay/link_relay.py @@ -65,9 +65,9 @@ class Connection: """ def __init__(self, room, websocket): - self.room = room + self.room = room self.websocket = websocket - self.address = str(uuid.uuid4()) + self.address = str(uuid.uuid4()) async def send_message(self, message): try: @@ -110,9 +110,9 @@ class Room: """ def __init__(self, relay, name): - self.relay = relay - self.name = name - self.observers = [] + self.relay = relay + self.name = name + self.observers = [] self.connections = [] async def add_connection(self, connection): @@ -145,7 +145,9 @@ class Room: # This is an RPC request await self.on_rpc_request(connection, message) else: - await connection.send_message(f'result:{error_to_json("error: invalid message")}') + await connection.send_message( + f'result:{error_to_json("error: invalid message")}' + ) async def broadcast_message(self, sender, message): ''' @@ -155,7 +157,9 @@ class Room: async def on_rpc_request(self, connection, message): command, *params = message.split(' ', 1) - if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None): + if handler := getattr( + self, f'on_{command[1:].lower().replace("-","_")}_command', None + ): try: result = await handler(connection, params) except Exception as error: @@ -192,7 +196,9 @@ class Room: current_address = connection.address new_address = params[0] connection.set_address(new_address) - await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}') + await self.broadcast_message( + connection, f'address-changed:from={current_address},to={new_address}' + ) # ---------------------------------------------------------------------------- @@ -246,24 +252,24 @@ def main(): print('ERROR: Python 3.6.1 or higher is required') sys.exit(1) - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) # Parse arguments arg_parser = argparse.ArgumentParser(description='Bumble Link Relay') arg_parser.add_argument('--log-level', default='INFO', help='logger level') arg_parser.add_argument('--log-config', help='logger config file (YAML)') - arg_parser.add_argument('--port', - type = int, - default = DEFAULT_RELAY_PORT, - help = 'Port to listen on') + arg_parser.add_argument( + '--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on' + ) args = arg_parser.parse_args() # Setup logger if args.log_config: from logging import config + config.fileConfig(args.log_config) else: - logging.basicConfig(level = getattr(logging, args.log_level.upper())) + logging.basicConfig(level=getattr(logging, args.log_level.upper())) # Start a relay relay = Relay(args.port) diff --git a/apps/pair.py b/apps/pair.py index 7f896299..69c68671 100644 --- a/apps/pair.py +++ b/apps/pair.py @@ -33,30 +33,32 @@ from bumble.gatt import ( GATT_GENERIC_ACCESS_SERVICE, Service, Characteristic, - CharacteristicValue + CharacteristicValue, ) from bumble.att import ( ATT_Error, ATT_INSUFFICIENT_AUTHENTICATION_ERROR, - ATT_INSUFFICIENT_ENCRYPTION_ERROR + ATT_INSUFFICIENT_ENCRYPTION_ERROR, ) # ----------------------------------------------------------------------------- class Delegate(PairingDelegate): def __init__(self, mode, connection, capability_string, prompt): - super().__init__({ - 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, - 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, - 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, - 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, - 'none': PairingDelegate.NO_OUTPUT_NO_INPUT - }[capability_string.lower()]) + super().__init__( + { + 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, + 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, + 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, + 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, + 'none': PairingDelegate.NO_OUTPUT_NO_INPUT, + }[capability_string.lower()] + ) - self.mode = mode - self.peer = Peer(connection) + self.mode = mode + self.peer = Peer(connection) self.peer_name = None - self.prompt = prompt + self.prompt = prompt async def update_peer_name(self): if self.peer_name is not None: @@ -103,7 +105,11 @@ class Delegate(PairingDelegate): print(color(f'### Pairing with {self.peer_name}', 'yellow')) print(color('###-----------------------------------', 'yellow')) while True: - response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow')) + response = await aioconsole.ainput( + color( + f'>>> Does the other device display {number:0{digits}}? ', 'yellow' + ) + ) response = response.lower().strip() if response == 'yes': return True @@ -149,7 +155,9 @@ async def get_peer_name(peer, mode): if not services: return None - values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0]) + values = await peer.read_characteristics_by_uuid( + GATT_DEVICE_NAME_CHARACTERISTIC, services[0] + ) if values: return values[0].decode('utf-8') @@ -183,14 +191,14 @@ def on_connection(connection, request): print(color(f'<<< Connection: {connection}', 'green')) # Listen for pairing events - connection.on('pairing_start', on_pairing_start) - connection.on('pairing', on_pairing) + connection.on('pairing_start', on_pairing_start) + connection.on('pairing', on_pairing) connection.on('pairing_failure', on_pairing_failure) # Listen for encryption changes connection.on( 'connection_encryption_change', - lambda: on_connection_encryption_change(connection) + lambda: on_connection_encryption_change(connection), ) # Request pairing if needed @@ -202,7 +210,12 @@ def on_connection(connection, request): # ----------------------------------------------------------------------------- def on_connection_encryption_change(connection): print(color('@@@-----------------------------------', 'blue')) - print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue')) + print( + color( + f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', + 'blue', + ) + ) print(color('@@@-----------------------------------', 'blue')) @@ -241,7 +254,7 @@ async def pair( keystore_file, device_config, hci_transport, - address_or_name + address_or_name, ): print('<<< connecting to HCI...') async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): @@ -272,9 +285,11 @@ async def pair( '552957FB-CF1F-4A31-9535-E78847E1A714', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - CharacteristicValue(read=read_with_error, write=write_with_error) + CharacteristicValue( + read=read_with_error, write=write_with_error + ), ) - ] + ], ) ) @@ -288,10 +303,7 @@ async def pair( # Set up a pairing config factory device.pairing_config_factory = lambda connection: PairingConfig( - sc, - mitm, - bond, - Delegate(mode, connection, io, prompt) + sc, mitm, bond, Delegate(mode, connection, io, prompt) ) # Connect to a peer or wait for a connection @@ -319,21 +331,70 @@ async def pair( # ----------------------------------------------------------------------------- @click.command() -@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True) -@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True) -@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True) -@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True) -@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True) +@click.option( + '--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True +) +@click.option( + '--sc', + type=bool, + default=True, + help='Use the Secure Connections protocol', + show_default=True, +) +@click.option( + '--mitm', type=bool, default=True, help='Request MITM protection', show_default=True +) +@click.option( + '--bond', type=bool, default=True, help='Enable bonding', show_default=True +) +@click.option( + '--io', + type=click.Choice( + ['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none'] + ), + default='display+keyboard', + show_default=True, +) @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') -@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing') +@click.option( + '--request', is_flag=True, help='Request that the connecting peer initiate pairing' +) @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option('--keystore-file', help='File in which to store the pairing keys') @click.argument('device-config') @click.argument('hci_transport') @click.argument('address-or-name', required=False) -def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) - asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name)) +def main( + mode, + sc, + mitm, + bond, + io, + prompt, + request, + print_keys, + keystore_file, + device_config, + hci_transport, + address_or_name, +): + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run( + pair( + mode, + sc, + mitm, + bond, + io, + prompt, + request, + print_keys, + keystore_file, + device_config, + hci_transport, + address_or_name, + ) + ) # ----------------------------------------------------------------------------- diff --git a/apps/scan.py b/apps/scan.py index d6c10923..1d4c2570 100644 --- a/apps/scan.py +++ b/apps/scan.py @@ -31,8 +31,8 @@ from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY # ----------------------------------------------------------------------------- def make_rssi_bar(rssi): - DISPLAY_MIN_RSSI = -105 - DISPLAY_MAX_RSSI = -30 + DISPLAY_MIN_RSSI = -105 + DISPLAY_MAX_RSSI = -30 DEFAULT_RSSI_BAR_WIDTH = 30 blocks = ['', '▏', '▎', '▍', '▌', '▋', '▊', '▉'] @@ -63,7 +63,9 @@ class AdvertisementPrinter: resolution_qualifier = f'(resolved from {advertisement.address})' address = resolved - address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type] + address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[ + address.address_type + ] if address.is_public: type_color = 'cyan' else: @@ -93,7 +95,8 @@ class AdvertisementPrinter: f'[{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}' f'{phy_info}' f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}' - f'{advertisement.data.to_string(separator)}\n') + f'{advertisement.data.to_string(separator)}\n' + ) def on_advertisement(self, advertisement): self.print_advertisement(advertisement) @@ -114,16 +117,20 @@ async def scan( raw, keystore_file, device_config, - transport + transport, ): print('<<< connecting to HCI...') async with await open_transport_or_link(transport) as (hci_source, hci_sink): print('<<< connected') if device_config: - device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) + device = Device.from_config_file_with_hci( + device_config, hci_source, hci_sink + ) else: - device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) + device = Device.with_hci( + 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink + ) if keystore_file: keystore = JsonKeyStore(namespace=None, filename=keystore_file) @@ -153,7 +160,7 @@ async def scan( scan_interval=scan_interval, scan_window=scan_window, filter_duplicates=filter_duplicates, - scanning_phys=scanning_phys + scanning_phys=scanning_phys, ) await hci_source.wait_for_termination() @@ -165,15 +172,51 @@ async def scan( @click.option('--passive', is_flag=True, default=False, help='Perform passive scanning') @click.option('--scan-interval', type=int, default=60, help='Scan interval') @click.option('--scan-window', type=int, default=60, help='Scan window') -@click.option('--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY') -@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level') -@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones') +@click.option( + '--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY' +) +@click.option( + '--filter-duplicates', + type=bool, + default=True, + help='Filter duplicates at the controller level', +) +@click.option( + '--raw', + is_flag=True, + default=False, + help='Listen for raw advertising reports instead of processed ones', +) @click.option('--keystore-file', help='Keystore file to use when resolving addresses') @click.option('--device-config', help='Device config file for the scanning device') @click.argument('transport') -def main(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) - asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport)) +def main( + min_rssi, + passive, + scan_interval, + scan_window, + phy, + filter_duplicates, + raw, + keystore_file, + device_config, + transport, +): + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + asyncio.run( + scan( + min_rssi, + passive, + scan_interval, + scan_window, + phy, + filter_duplicates, + raw, + keystore_file, + device_config, + transport, + ) + ) # ----------------------------------------------------------------------------- diff --git a/apps/show.py b/apps/show.py index a4efe04b..6329791f 100644 --- a/apps/show.py +++ b/apps/show.py @@ -30,10 +30,10 @@ class SnoopPacketReader: Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...) ''' - DATALINK_H1 = 1001 - DATALINK_H4 = 1002 + DATALINK_H1 = 1001 + DATALINK_H4 = 1002 DATALINK_BSCP = 1003 - DATALINK_H5 = 1004 + DATALINK_H5 = 1004 def __init__(self, source): self.source = source @@ -41,9 +41,16 @@ class SnoopPacketReader: # Read the header identification_pattern = source.read(8) if identification_pattern.hex().lower() != '6274736e6f6f7000': - raise ValueError('not a valid snoop file, unexpected identification pattern') - (self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8)) - if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1: + raise ValueError( + 'not a valid snoop file, unexpected identification pattern' + ) + (self.version_number, self.data_link_type) = struct.unpack( + '>II', source.read(8) + ) + if ( + self.data_link_type != self.DATALINK_H4 + and self.data_link_type != self.DATALINK_H1 + ): raise ValueError(f'datalink type {self.data_link_type} not supported') def next_packet(self): @@ -57,7 +64,7 @@ class SnoopPacketReader: packet_flags, cumulative_drops, timestamp_seconds, - timestamp_microsecond + timestamp_microsecond, ) = struct.unpack('>IIIIII', header) # Abort on truncated packets @@ -79,7 +86,10 @@ class SnoopPacketReader: else: packet_type = hci.HCI_ACL_DATA_PACKET - return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length)) + return ( + packet_flags & 1, + bytes([packet_type]) + self.source.read(included_length), + ) else: return (packet_flags & 1, self.source.read(included_length)) @@ -88,7 +98,12 @@ class SnoopPacketReader: # Main # ----------------------------------------------------------------------------- @click.command() -@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file') +@click.option( + '--format', + type=click.Choice(['h4', 'snoop']), + default='h4', + help='Format of the input file', +) @click.argument('filename') def main(format, filename): input = open(filename, 'rb') @@ -97,6 +112,7 @@ def main(format, filename): def read_next_packet(): (0, packet_reader.next_packet()) + else: packet_reader = SnoopPacketReader(input) read_next_packet = packet_reader.next_packet diff --git a/apps/unbond.py b/apps/unbond.py index cf1877c3..105d9a48 100644 --- a/apps/unbond.py +++ b/apps/unbond.py @@ -54,7 +54,7 @@ async def unbond(keystore_file, device_config, address): @click.argument('device-config') @click.argument('address', required=False) def main(keystore_file, device_config, address): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(unbond(keystore_file, device_config, address)) diff --git a/apps/usb_probe.py b/apps/usb_probe.py index 26b7f409..00a04fa1 100644 --- a/apps/usb_probe.py +++ b/apps/usb_probe.py @@ -37,9 +37,9 @@ from colors import color # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -USB_DEVICE_CLASS_DEVICE = 0x00 -USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 -USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 +USB_DEVICE_CLASS_DEVICE = 0x00 +USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 +USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 USB_DEVICE_CLASSES = { @@ -69,22 +69,22 @@ USB_DEVICE_CLASSES = { 0x01: 'Bluetooth', 0x02: 'UWB', 0x03: 'Remote NDIS', - 0x04: 'Bluetooth AMP' + 0x04: 'Bluetooth AMP', } - } + }, ), 0xEF: 'Miscellaneous', 0xFE: 'Application Specific', - 0xFF: 'Vendor Specific' + 0xFF: 'Vendor Specific', } -USB_ENDPOINT_IN = 0x80 +USB_ENDPOINT_IN = 0x80 USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT'] USB_BT_HCI_CLASS_TUPLE = ( USB_DEVICE_CLASS_WIRELESS_CONTROLLER, USB_DEVICE_SUBCLASS_RF_CONTROLLER, - USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER + USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, ) @@ -95,18 +95,24 @@ def show_device_details(device): for interface in configuration: for setting in interface: alternateSetting = setting.getAlternateSetting() - suffix = f'/{alternateSetting}' if interface.getNumSettings() > 1 else '' + suffix = ( + f'/{alternateSetting}' if interface.getNumSettings() > 1 else '' + ) (class_string, subclass_string) = get_class_info( - setting.getClass(), - setting.getSubClass(), - setting.getProtocol() + setting.getClass(), setting.getSubClass(), setting.getProtocol() ) details = f'({class_string}, {subclass_string})' print(f' Interface: {setting.getNumber()}{suffix} {details}') for endpoint in setting: endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3] - endpoint_direction = 'OUT' if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) else 'IN' - print(f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}') + endpoint_direction = ( + 'OUT' + if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) + else 'IN' + ) + print( + f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}' + ) # ----------------------------------------------------------------------------- @@ -135,7 +141,11 @@ def get_class_info(cls, subclass, protocol): # ----------------------------------------------------------------------------- def is_bluetooth_hci(device): # Check if the device class indicates a match - if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == USB_BT_HCI_CLASS_TUPLE: + if ( + device.getDeviceClass(), + device.getDeviceSubClass(), + device.getDeviceProtocol(), + ) == USB_BT_HCI_CLASS_TUPLE: return True # If the device class is 'Device', look for a matching interface @@ -143,7 +153,11 @@ def is_bluetooth_hci(device): for configuration in device: for interface in configuration: for setting in interface: - if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == USB_BT_HCI_CLASS_TUPLE: + if ( + setting.getClass(), + setting.getSubClass(), + setting.getProtocol(), + ) == USB_BT_HCI_CLASS_TUPLE: return True return False @@ -153,23 +167,21 @@ def is_bluetooth_hci(device): @click.command() @click.option('--verbose', is_flag=True, default=False, help='Print more details') def main(verbose): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) with usb1.USBContext() as context: bluetooth_device_count = 0 devices = {} for device in context.getDeviceIterator(skip_on_error=True): - device_class = device.getDeviceClass() + device_class = device.getDeviceClass() device_subclass = device.getDeviceSubClass() device_protocol = device.getDeviceProtocol() device_id = (device.getVendorID(), device.getProductID()) (device_class_string, device_subclass_string) = get_class_info( - device_class, - device_subclass, - device_protocol + device_class, device_subclass, device_protocol ) try: @@ -198,7 +210,9 @@ def main(verbose): # Compute the different ways this can be referenced as a Bumble transport bumble_transport_names = [] - basic_transport_name = f'usb:{device.getVendorID():04X}:{device.getProductID():04X}' + basic_transport_name = ( + f'usb:{device.getVendorID():04X}:{device.getProductID():04X}' + ) if device_is_bluetooth_hci: bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}') @@ -206,17 +220,39 @@ def main(verbose): if device_id not in devices: bumble_transport_names.append(basic_transport_name) else: - bumble_transport_names.append(f'{basic_transport_name}#{len(devices[device_id])}') + bumble_transport_names.append( + f'{basic_transport_name}#{len(devices[device_id])}' + ) if device_serial_number is not None: - if device_id not in devices or device_serial_number not in devices[device_id]: - bumble_transport_names.append(f'{basic_transport_name}/{device_serial_number}') + if ( + device_id not in devices + or device_serial_number not in devices[device_id] + ): + bumble_transport_names.append( + f'{basic_transport_name}/{device_serial_number}' + ) # Print the results - print(color(f'ID {device.getVendorID():04X}:{device.getProductID():04X}', fg=fg_color, bg=bg_color)) + print( + color( + f'ID {device.getVendorID():04X}:{device.getProductID():04X}', + fg=fg_color, + bg=bg_color, + ) + ) if bumble_transport_names: - print(color(' Bumble Transport Names:', 'blue'), ' or '.join(color(x, 'cyan' if device_is_bluetooth_hci else 'red') for x in bumble_transport_names)) - print(color(' Bus/Device: ', 'green'), f'{device.getBusNumber():03}/{device.getDeviceAddress():03}') + print( + color(' Bumble Transport Names:', 'blue'), + ' or '.join( + color(x, 'cyan' if device_is_bluetooth_hci else 'red') + for x in bumble_transport_names + ), + ) + print( + color(' Bus/Device: ', 'green'), + f'{device.getBusNumber():03}/{device.getDeviceAddress():03}', + ) print(color(' Class: ', 'green'), device_class_string) print(color(' Subclass/Protocol: ', 'green'), device_subclass_string) if device_serial_number is not None: diff --git a/bumble/a2dp.py b/bumble/a2dp.py index 03c3fd20..9e2783e2 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -30,7 +30,7 @@ from .sdp import ( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) from .core import ( BT_L2CAP_PROTOCOL_ID, @@ -38,7 +38,7 @@ from .core import ( BT_AUDIO_SINK_SERVICE, BT_AVDTP_PROTOCOL_ID, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, - name_or_number + name_or_number, ) @@ -51,6 +51,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off A2DP_SBC_CODEC_TYPE = 0x00 A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01 @@ -127,6 +128,8 @@ MPEG_2_4_OBJECT_TYPE_NAMES = { MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE' } +# fmt: on + # ----------------------------------------------------------------------------- def flags_to_list(flags, values): @@ -140,58 +143,98 @@ def flags_to_list(flags, values): # ----------------------------------------------------------------------------- def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)): from .avdtp import AVDTP_PSM + version_int = version[0] << 8 | version[1] return [ - ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), - ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) - ])), - ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(BT_AUDIO_SOURCE_SERVICE) - ])), - ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID), - DataElement.unsigned_integer_16(AVDTP_PSM) - ]), - DataElement.sequence([ - DataElement.uuid(BT_AVDTP_PROTOCOL_ID), - DataElement.unsigned_integer_16(version_int) - ]) - ])), - ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int) - ])), + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(AVDTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVDTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ), + ), ] # ----------------------------------------------------------------------------- def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): from .avdtp import AVDTP_PSM + version_int = version[0] << 8 | version[1] return [ - ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), - ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) - ])), - ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(BT_AUDIO_SINK_SERVICE) - ])), - ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID), - DataElement.unsigned_integer_16(AVDTP_PSM) - ]), - DataElement.sequence([ - DataElement.uuid(BT_AVDTP_PROTOCOL_ID), - DataElement.unsigned_integer_16(version_int) - ]) - ])), - ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int) - ])), + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(AVDTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVDTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ), + ), ] @@ -206,8 +249,8 @@ class SbcMediaCodecInformation( 'subbands', 'allocation_method', 'minimum_bitpool_value', - 'maximum_bitpool_value' - ] + 'maximum_bitpool_value', + ], ) ): ''' @@ -215,36 +258,25 @@ class SbcMediaCodecInformation( ''' BIT_FIELDS = 'u4u4u4u2u2u8u8' - SAMPLING_FREQUENCY_BITS = { - 16000: 1 << 3, - 32000: 1 << 2, - 44100: 1 << 1, - 48000: 1 - } + SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1} CHANNEL_MODE_BITS = { - SBC_MONO_CHANNEL_MODE: 1 << 3, - SBC_DUAL_CHANNEL_MODE: 1 << 2, - SBC_STEREO_CHANNEL_MODE: 1 << 1, - SBC_JOINT_STEREO_CHANNEL_MODE: 1 - } - BLOCK_LENGTH_BITS = { - 4: 1 << 3, - 8: 1 << 2, - 12: 1 << 1, - 16: 1 - } - SUBBANDS_BITS = { - 4: 1 << 1, - 8: 1 + SBC_MONO_CHANNEL_MODE: 1 << 3, + SBC_DUAL_CHANNEL_MODE: 1 << 2, + SBC_STEREO_CHANNEL_MODE: 1 << 1, + SBC_JOINT_STEREO_CHANNEL_MODE: 1, } + BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1} + SUBBANDS_BITS = {4: 1 << 1, 8: 1} ALLOCATION_METHOD_BITS = { - SBC_SNR_ALLOCATION_METHOD: 1 << 1, - SBC_LOUDNESS_ALLOCATION_METHOD: 1 + SBC_SNR_ALLOCATION_METHOD: 1 << 1, + SBC_LOUDNESS_ALLOCATION_METHOD: 1, } @staticmethod def from_bytes(data): - return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)) + return SbcMediaCodecInformation( + *bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data) + ) @classmethod def from_discrete_values( @@ -255,16 +287,16 @@ class SbcMediaCodecInformation( subbands, allocation_method, minimum_bitpool_value, - maximum_bitpool_value + maximum_bitpool_value, ): return SbcMediaCodecInformation( - sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], - channel_mode = cls.CHANNEL_MODE_BITS[channel_mode], - block_length = cls.BLOCK_LENGTH_BITS[block_length], - subbands = cls.SUBBANDS_BITS[subbands], - allocation_method = cls.ALLOCATION_METHOD_BITS[allocation_method], - minimum_bitpool_value = minimum_bitpool_value, - maximum_bitpool_value = maximum_bitpool_value + sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], + channel_mode=cls.CHANNEL_MODE_BITS[channel_mode], + block_length=cls.BLOCK_LENGTH_BITS[block_length], + subbands=cls.SUBBANDS_BITS[subbands], + allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method], + minimum_bitpool_value=minimum_bitpool_value, + maximum_bitpool_value=maximum_bitpool_value, ) @classmethod @@ -276,16 +308,20 @@ class SbcMediaCodecInformation( subbands, allocation_methods, minimum_bitpool_value, - maximum_bitpool_value + maximum_bitpool_value, ): return SbcMediaCodecInformation( - sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), - channel_mode = sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes), - block_length = sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths), - subbands = sum(cls.SUBBANDS_BITS[x] for x in subbands), - allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods), - minimum_bitpool_value = minimum_bitpool_value, - maximum_bitpool_value = maximum_bitpool_value + sampling_frequency=sum( + cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies + ), + channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes), + block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths), + subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands), + allocation_method=sum( + cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods + ), + minimum_bitpool_value=minimum_bitpool_value, + maximum_bitpool_value=maximum_bitpool_value, ) def __bytes__(self): @@ -294,30 +330,25 @@ class SbcMediaCodecInformation( def __str__(self): channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] allocation_methods = ['SNR', 'Loudness'] - return '\n'.join([ - 'SbcMediaCodecInformation(', - f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}', - f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}', - f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}', - f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}', - f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}', - f' minimum_bitpool_value: {self.minimum_bitpool_value}', - f' maximum_bitpool_value: {self.maximum_bitpool_value}' - ')' - ]) + return '\n'.join( + [ + 'SbcMediaCodecInformation(', + f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}', + f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}', + f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}', + f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}', + f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}', + f' minimum_bitpool_value: {self.minimum_bitpool_value}', + f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')', + ] + ) # ----------------------------------------------------------------------------- class AacMediaCodecInformation( namedtuple( 'AacMediaCodecInformation', - [ - 'object_type', - 'sampling_frequency', - 'channels', - 'vbr', - 'bitrate' - ] + ['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'], ) ): ''' @@ -326,13 +357,13 @@ class AacMediaCodecInformation( BIT_FIELDS = 'u8u12u2p2u1u23' OBJECT_TYPE_BITS = { - MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, - MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, - MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5, - MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4 + MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, + MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, + MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5, + MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4, } SAMPLING_FREQUENCY_BITS = { - 8000: 1 << 11, + 8000: 1 << 11, 11025: 1 << 10, 12000: 1 << 9, 16000: 1 << 8, @@ -343,66 +374,65 @@ class AacMediaCodecInformation( 48000: 1 << 3, 64000: 1 << 2, 88200: 1 << 1, - 96000: 1 - } - CHANNELS_BITS = { - 1: 1 << 1, - 2: 1 + 96000: 1, } + CHANNELS_BITS = {1: 1 << 1, 2: 1} @staticmethod def from_bytes(data): - return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)) - - @classmethod - def from_discrete_values( - cls, - object_type, - sampling_frequency, - channels, - vbr, - bitrate - ): return AacMediaCodecInformation( - object_type = cls.OBJECT_TYPE_BITS[object_type], - sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], - channels = cls.CHANNELS_BITS[channels], - vbr = vbr, - bitrate = bitrate + *bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data) ) @classmethod - def from_lists( - cls, - object_types, - sampling_frequencies, - channels, - vbr, - bitrate + def from_discrete_values( + cls, object_type, sampling_frequency, channels, vbr, bitrate ): return AacMediaCodecInformation( - object_type = sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), - sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), - channels = sum(cls.CHANNELS_BITS[x] for x in channels), - vbr = vbr, - bitrate = bitrate + object_type=cls.OBJECT_TYPE_BITS[object_type], + sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], + channels=cls.CHANNELS_BITS[channels], + vbr=vbr, + bitrate=bitrate, + ) + + @classmethod + def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate): + return AacMediaCodecInformation( + object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), + sampling_frequency=sum( + cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies + ), + channels=sum(cls.CHANNELS_BITS[x] for x in channels), + vbr=vbr, + bitrate=bitrate, ) def __bytes__(self): return bitstruct.pack(self.BIT_FIELDS, *self) def __str__(self): - object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]'] + object_types = [ + 'MPEG_2_AAC_LC', + 'MPEG_4_AAC_LC', + 'MPEG_4_AAC_LTP', + 'MPEG_4_AAC_SCALABLE', + '[4]', + '[5]', + '[6]', + '[7]', + ] channels = [1, 2] - return '\n'.join([ - 'AacMediaCodecInformation(', - f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}', - f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}', - f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}', - f' vbr: {self.vbr}', - f' bitrate: {self.bitrate}' - ')' - ]) + return '\n'.join( + [ + 'AacMediaCodecInformation(', + f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}', + f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}', + f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}', + f' vbr: {self.vbr}', + f' bitrate: {self.bitrate}' ')', + ] + ) # ----------------------------------------------------------------------------- @@ -418,37 +448,33 @@ class VendorSpecificMediaCodecInformation: def __init__(self, vendor_id, codec_id, value): self.vendor_id = vendor_id - self.codec_id = codec_id - self.value = value + self.codec_id = codec_id + self.value = value def __bytes__(self): return struct.pack('> 6) & 3] - blocks = 4 * (1 + ((header[1] >> 4) & 3)) - channel_mode = (header[1] >> 2) & 3 - channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2 - subbands = 8 if ((header[1]) & 1) else 4 - bitpool = header[2] + blocks = 4 * (1 + ((header[1] >> 4) & 3)) + channel_mode = (header[1] >> 2) & 3 + channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2 + subbands = 8 if ((header[1]) & 1) else 4 + bitpool = header[2] # Compute the frame length frame_length = 4 + (4 * subbands * channels) // 8 if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE): frame_length += (blocks * channels * bitpool) // 8 else: - frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8 + frame_length += ( + (1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) + * subbands + + blocks * bitpool + ) // 8 # Read the rest of the frame payload = header + await self.read(frame_length - 4) # Emit the next frame - yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload) + yield SbcFrame( + sampling_frequency, blocks, channel_mode, subbands, payload + ) return generate_frames() @@ -512,8 +544,8 @@ class SbcParser: # ----------------------------------------------------------------------------- class SbcPacketSource: def __init__(self, read, mtu, codec_capabilities): - self.read = read - self.mtu = mtu + self.read = read + self.mtu = mtu self.codec_capabilities = codec_capabilities @property @@ -522,9 +554,9 @@ class SbcPacketSource: from .avdtp import MediaPacket # Import here to avoid a circular reference sequence_number = 0 - timestamp = 0 - frames = [] - frames_size = 0 + timestamp = 0 + frames = [] + frames_size = 0 max_rtp_payload = self.mtu - 12 - 1 # NOTE: this doesn't support frame fragments @@ -532,12 +564,19 @@ class SbcPacketSource: async for frame in sbc_parser.frames: print(frame) - if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16: + if ( + frames_size + len(frame.payload) > max_rtp_payload + or len(frames) == 16 + ): # Need to flush what has been accumulated so far # Emit a packet - sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames]) - packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload) + sbc_payload = bytes([len(frames)]) + b''.join( + [frame.payload for frame in frames] + ) + packet = MediaPacket( + 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload + ) packet.timestamp_seconds = timestamp / frame.sampling_frequency yield packet diff --git a/bumble/att.py b/bumble/att.py index d83b3d14..febd6edb 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -31,6 +31,8 @@ from .hci import * # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + ATT_CID = 0x04 ATT_ERROR_RESPONSE = 0x01 @@ -166,6 +168,8 @@ HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'} UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731 UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 +# fmt: on + # ----------------------------------------------------------------------------- # Utils @@ -196,6 +200,7 @@ class ATT_PDU: ''' See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU ''' + pdu_classes = {} op_code = 0 @@ -274,11 +279,13 @@ class ATT_PDU: # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([ - ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), - ('attribute_handle_in_error', HANDLE_FIELD_SPEC), - ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}) -]) +@ATT_PDU.subclass( + [ + ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), + ('attribute_handle_in_error', HANDLE_FIELD_SPEC), + ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}), + ] +) class ATT_Error_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response @@ -286,9 +293,7 @@ class ATT_Error_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([ - ('client_rx_mtu', 2) -]) +@ATT_PDU.subclass([('client_rx_mtu', 2)]) class ATT_Exchange_MTU_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request @@ -296,9 +301,7 @@ class ATT_Exchange_MTU_Request(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([ - ('server_rx_mtu', 2) -]) +@ATT_PDU.subclass([('server_rx_mtu', 2)]) class ATT_Exchange_MTU_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response @@ -306,10 +309,9 @@ class ATT_Exchange_MTU_Response(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([ - ('starting_handle', HANDLE_FIELD_SPEC), - ('ending_handle', HANDLE_FIELD_SPEC) -]) +@ATT_PDU.subclass( + [('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)] +) class ATT_Find_Information_Request(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request @@ -317,10 +319,7 @@ class ATT_Find_Information_Request(ATT_PDU): # ----------------------------------------------------------------------------- -@ATT_PDU.subclass([ - ('format', 1), - ('information_data', '*') -]) +@ATT_PDU.subclass([('format', 1), ('information_data', '*')]) class ATT_Find_Information_Response(ATT_PDU): ''' See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response @@ -332,7 +331,7 @@ class ATT_Find_Information_Response(ATT_PDU): uuid_size = 2 if self.format == 1 else 16 while offset + uuid_size <= len(self.information_data): handle = struct.unpack_from('> 6) & 0x03 - padding = (data[0] >> 5) & 0x01 - extension = (data[0] >> 4) & 0x01 - csrc_count = data[0] & 0x0F - marker = (data[1] >> 7) & 0x01 - payload_type = data[1] & 0x7F + version = (data[0] >> 6) & 0x03 + padding = (data[0] >> 5) & 0x01 + extension = (data[0] >> 4) & 0x01 + csrc_count = data[0] & 0x0F + marker = (data[1] >> 7) & 0x01 + payload_type = data[1] & 0x7F sequence_number = struct.unpack_from('>H', data, 2)[0] - timestamp = struct.unpack_from('>I', data, 4)[0] - ssrc = struct.unpack_from('>I', data, 8)[0] - csrc_list = [struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)] - payload = data[12 + csrc_count * 4:] + timestamp = struct.unpack_from('>I', data, 4)[0] + ssrc = struct.unpack_from('>I', data, 8)[0] + csrc_list = [ + struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count) + ] + payload = data[12 + csrc_count * 4 :] return MediaPacket( version, @@ -273,7 +276,7 @@ class MediaPacket: ssrc, csrc_list, payload_type, - payload + payload, ) def __init__( @@ -287,27 +290,29 @@ class MediaPacket: ssrc, csrc_list, payload_type, - payload + payload, ): - self.version = version - self.padding = padding - self.extension = extension - self.marker = marker + self.version = version + self.padding = padding + self.extension = extension + self.marker = marker self.sequence_number = sequence_number - self.timestamp = timestamp - self.ssrc = ssrc - self.csrc_list = csrc_list - self.payload_type = payload_type - self.payload = payload + self.timestamp = timestamp + self.ssrc = ssrc + self.csrc_list = csrc_list + self.payload_type = payload_type + self.payload = payload def __bytes__(self): - header = ( - bytes([ - self.version << 6 | self.padding << 5 | self.extension << 4 | len(self.csrc_list), - self.marker << 7 | self.payload_type - ]) + - struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc) - ) + header = bytes( + [ + self.version << 6 + | self.padding << 5 + | self.extension << 4 + | len(self.csrc_list), + self.marker << 7 | self.payload_type, + ] + ) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc) for csrc in self.csrc_list: header += struct.pack('>I', csrc) return header + self.payload @@ -319,13 +324,13 @@ class MediaPacket: # ----------------------------------------------------------------------------- class MediaPacketPump: def __init__(self, packets, clock=RealtimeClock()): - self.packets = packets - self.clock = clock + self.packets = packets + self.clock = clock self.pump_task = None async def start(self, rtp_channel): async def pump_packets(): - start_time = 0 + start_time = 0 start_timestamp = 0 try: @@ -333,7 +338,7 @@ class MediaPacketPump: async for packet in self.packets: # Capture the timestamp of the first packet if start_time == 0: - start_time = self.clock.now() + start_time = self.clock.now() start_timestamp = packet.timestamp_seconds # Wait until we can send @@ -346,7 +351,9 @@ class MediaPacketPump: # Emit rtp_channel.send_pdu(bytes(packet)) - logger.debug(f'{color(">>> sending RTP packet:", "green")} {packet}') + logger.debug( + f'{color(">>> sending RTP packet:", "green")} {packet}' + ) except asyncio.exceptions.CancelledError: logger.debug('pump canceled') @@ -368,67 +375,87 @@ class MessageAssembler: self.reset() def reset(self): - self.transaction_label = 0 - self.message = None - self.message_type = 0 - self.signal_identifier = 0 + self.transaction_label = 0 + self.message = None + self.message_type = 0 + self.signal_identifier = 0 self.number_of_signal_packets = 0 - self.packet_count = 0 + self.packet_count = 0 def on_pdu(self, pdu): self.packet_count += 1 transaction_label = pdu[0] >> 4 - packet_type = (pdu[0] >> 2) & 3 - message_type = pdu[0] & 3 + packet_type = (pdu[0] >> 2) & 3 + message_type = pdu[0] & 3 - logger.debug(f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}') - if packet_type == Protocol.SINGLE_PACKET or packet_type == Protocol.START_PACKET: + logger.debug( + f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}' + ) + if ( + packet_type == Protocol.SINGLE_PACKET + or packet_type == Protocol.START_PACKET + ): if self.message is not None: # The previous message has not been terminated - logger.warning('received a start or single packet when expecting an end or continuation') + logger.warning( + 'received a start or single packet when expecting an end or continuation' + ) self.reset() self.transaction_label = transaction_label self.signal_identifier = pdu[1] & 0x3F - self.message_type = message_type + self.message_type = message_type if packet_type == Protocol.SINGLE_PACKET: self.message = pdu[2:] self.on_message_complete() else: self.number_of_signal_packets = pdu[2] - self.message = pdu[3:] - elif packet_type == Protocol.CONTINUE_PACKET or packet_type == Protocol.END_PACKET: + self.message = pdu[3:] + elif ( + packet_type == Protocol.CONTINUE_PACKET + or packet_type == Protocol.END_PACKET + ): if self.packet_count == 0: logger.warning('unexpected continuation') return if transaction_label != self.transaction_label: - logger.warning(f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}') + logger.warning( + f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}' + ) return if message_type != self.message_type: - logger.warning(f'message type mismatch: expected {self.message_type}, received {message_type}') + logger.warning( + f'message type mismatch: expected {self.message_type}, received {message_type}' + ) return self.message += pdu[1:] if packet_type == Protocol.END_PACKET: if self.packet_count != self.number_of_signal_packets: - logger.warning(f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}') + logger.warning( + f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}' + ) self.reset() return self.on_message_complete() else: if self.packet_count > self.number_of_signal_packets: - logger.warning(f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}') + logger.warning( + f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}' + ) self.reset() return def on_message_complete(self): - message = Message.create(self.signal_identifier, self.message_type, self.message) + message = Message.create( + self.signal_identifier, self.message_type, self.message + ) try: self.callback(self.transaction_label, message) @@ -460,12 +487,14 @@ class ServiceCapabilities: def parse_capabilities(payload): capabilities = [] while payload: - service_category = payload[0] + service_category = payload[0] length_of_service_capabilities = payload[1] - service_capabilities_bytes = payload[2:2 + length_of_service_capabilities] - capabilities.append(ServiceCapabilities.create(service_category, service_capabilities_bytes)) + service_capabilities_bytes = payload[2 : 2 + length_of_service_capabilities] + capabilities.append( + ServiceCapabilities.create(service_category, service_capabilities_bytes) + ) - payload = payload[2 + length_of_service_capabilities:] + payload = payload[2 + length_of_service_capabilities :] return capabilities @@ -473,21 +502,24 @@ class ServiceCapabilities: def serialize_capabilities(capabilities): serialized = b'' for item in capabilities: - serialized += bytes([ - item.service_category, - len(item.service_capabilities_bytes) - ]) + item.service_capabilities_bytes + serialized += ( + bytes([item.service_category, len(item.service_capabilities_bytes)]) + + item.service_capabilities_bytes + ) return serialized def init_from_bytes(self): pass def __init__(self, service_category, service_capabilities_bytes=b''): - self.service_category = service_category + self.service_category = service_category self.service_capabilities_bytes = service_capabilities_bytes def to_string(self, details=[]): - attributes = ','.join([name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + details) + attributes = ','.join( + [name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + + details + ) return f'ServiceCapabilities({attributes})' def __str__(self): @@ -501,31 +533,39 @@ class ServiceCapabilities: # ----------------------------------------------------------------------------- class MediaCodecCapabilities(ServiceCapabilities): def init_from_bytes(self): - self.media_type = self.service_capabilities_bytes[0] - self.media_codec_type = self.service_capabilities_bytes[1] + self.media_type = self.service_capabilities_bytes[0] + self.media_codec_type = self.service_capabilities_bytes[1] self.media_codec_information = self.service_capabilities_bytes[2:] if self.media_codec_type == A2DP_SBC_CODEC_TYPE: - self.media_codec_information = SbcMediaCodecInformation.from_bytes(self.media_codec_information) + self.media_codec_information = SbcMediaCodecInformation.from_bytes( + self.media_codec_information + ) elif self.media_codec_type == A2DP_MPEG_2_4_AAC_CODEC_TYPE: - self.media_codec_information = AacMediaCodecInformation.from_bytes(self.media_codec_information) + self.media_codec_information = AacMediaCodecInformation.from_bytes( + self.media_codec_information + ) elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE: - self.media_codec_information = VendorSpecificMediaCodecInformation.from_bytes(self.media_codec_information) + self.media_codec_information = ( + VendorSpecificMediaCodecInformation.from_bytes( + self.media_codec_information + ) + ) def __init__(self, media_type, media_codec_type, media_codec_information): super().__init__( AVDTP_MEDIA_CODEC_SERVICE_CATEGORY, - bytes([media_type, media_codec_type]) + bytes(media_codec_information) + bytes([media_type, media_codec_type]) + bytes(media_codec_information), ) - self.media_type = media_type - self.media_codec_type = media_codec_type + self.media_type = media_type + self.media_codec_type = media_codec_type self.media_codec_information = media_codec_information def __str__(self): details = [ f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}', f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}', - f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}' + f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}', ] return self.to_string(details) @@ -535,37 +575,33 @@ class EndPointInfo: @staticmethod def from_bytes(payload): return EndPointInfo( - payload[0] >> 2, - payload[0] >> 1 & 1, - payload[1] >> 4, - payload[1] >> 3 & 1 + payload[0] >> 2, payload[0] >> 1 & 1, payload[1] >> 4, payload[1] >> 3 & 1 ) def __bytes__(self): - return bytes([ - self.seid << 2 | self.in_use << 1, - self.media_type << 4 | self.tsep << 3 - ]) + return bytes( + [self.seid << 2 | self.in_use << 1, self.media_type << 4 | self.tsep << 3] + ) def __init__(self, seid, in_use, media_type, tsep): - self.seid = seid - self.in_use = in_use + self.seid = seid + self.in_use = in_use self.media_type = media_type - self.tsep = tsep + self.tsep = tsep # ----------------------------------------------------------------------------- class Message: - COMMAND = 0 - GENERAL_REJECT = 1 + COMMAND = 0 + GENERAL_REJECT = 1 RESPONSE_ACCEPT = 2 RESPONSE_REJECT = 3 MESSAGE_TYPE_NAMES = { - COMMAND: 'COMMAND', - GENERAL_REJECT: 'GENERAL_REJECT', + COMMAND: 'COMMAND', + GENERAL_REJECT: 'GENERAL_REJECT', RESPONSE_ACCEPT: 'RESPONSE_ACCEPT', - RESPONSE_REJECT: 'RESPONSE_REJECT' + RESPONSE_REJECT: 'RESPONSE_REJECT', } subclasses = {} # Subclasses, by signal identifier and message type @@ -603,7 +639,9 @@ class Message: break # Register the subclass - Message.subclasses.setdefault(cls.signal_identifier, {})[cls.message_type] = cls + Message.subclasses.setdefault(cls.signal_identifier, {})[ + cls.message_type + ] = cls return cls @@ -635,7 +673,7 @@ class Message: pass def __init__(self, payload=b''): - self.payload = payload + self.payload = payload def to_string(self, details): base = f'{color(f"{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_{Message.message_type_name(self.message_type)}", "yellow")}' @@ -643,7 +681,11 @@ class Message: if type(details) is str: return f'{base}: {details}' else: - return base + ':\n' + '\n'.join([' ' + color(detail, 'cyan') for detail in details]) + return ( + base + + ':\n' + + '\n'.join([' ' + color(detail, 'cyan') for detail in details]) + ) else: return base @@ -682,9 +724,7 @@ class Simple_Reject(Message): self.payload = bytes([self.error_code]) def __str__(self): - details = [ - f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}' - ] + details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'] return self.to_string(details) @@ -707,11 +747,13 @@ class Discover_Response(Message): self.endpoints = [] endpoint_count = len(self.payload) // 2 for i in range(endpoint_count): - self.endpoints.append(EndPointInfo.from_bytes(self.payload[i * 2:(i + 1) * 2])) + self.endpoints.append( + EndPointInfo.from_bytes(self.payload[i * 2 : (i + 1) * 2]) + ) def __init__(self, endpoints): self.endpoints = endpoints - self.payload = b''.join([bytes(endpoint) for endpoint in endpoints]) + self.payload = b''.join([bytes(endpoint) for endpoint in endpoints]) def __str__(self): details = [] @@ -721,7 +763,7 @@ class Discover_Response(Message): f'ACP SEID: {endpoint.seid}', f' in_use: {endpoint.in_use}', f' media_type: {name_or_number(AVDTP_MEDIA_TYPE_NAMES, endpoint.media_type)}', - f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}' + f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}', ] ) return self.to_string(details) @@ -794,21 +836,22 @@ class Set_Configuration_Command(Message): ''' def init_from_payload(self): - self.acp_seid = self.payload[0] >> 2 - self.int_seid = self.payload[1] >> 2 + self.acp_seid = self.payload[0] >> 2 + self.int_seid = self.payload[1] >> 2 self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[2:]) def __init__(self, acp_seid, int_seid, capabilities): - self.acp_seid = acp_seid - self.int_seid = int_seid + self.acp_seid = acp_seid + self.int_seid = int_seid self.capabilities = capabilities - self.payload = bytes([acp_seid << 2, int_seid << 2]) + ServiceCapabilities.serialize_capabilities(capabilities) + self.payload = bytes( + [acp_seid << 2, int_seid << 2] + ) + ServiceCapabilities.serialize_capabilities(capabilities) def __str__(self): - details = [ - f'ACP SEID: {self.acp_seid}', - f'INT SEID: {self.int_seid}' - ] + [str(capability) for capability in self.capabilities] + details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [ + str(capability) for capability in self.capabilities + ] return self.to_string(details) @@ -829,17 +872,17 @@ class Set_Configuration_Reject(Message): def init_from_payload(self): self.service_category = self.payload[0] - self.error_code = self.payload[1] + self.error_code = self.payload[1] def __init__(self, service_category, error_code): self.service_category = service_category - self.error_code = error_code - self.payload = bytes([service_category, self.error_code]) + self.error_code = error_code + self.payload = bytes([service_category, self.error_code]) def __str__(self): details = [ f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}', - f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}' + f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}', ] return self.to_string(details) @@ -887,7 +930,7 @@ class Reconfigure_Command(Message): ''' def init_from_payload(self): - self.acp_seid = self.payload[0] >> 2 + self.acp_seid = self.payload[0] >> 2 self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[1:]) def __str__(self): @@ -971,18 +1014,18 @@ class Start_Reject(Message): ''' def init_from_payload(self): - self.acp_seid = self.payload[0] >> 2 + self.acp_seid = self.payload[0] >> 2 self.error_code = self.payload[1] def __init__(self, acp_seid, error_code): - self.acp_seid = acp_seid + self.acp_seid = acp_seid self.error_code = error_code - self.payload = bytes([self.acp_seid << 2, self.error_code]) + self.payload = bytes([self.acp_seid << 2, self.error_code]) def __str__(self): details = [ f'acp_seid: {self.acp_seid}', - f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}' + f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}', ] return self.to_string(details) @@ -1095,13 +1138,10 @@ class DelayReport_Command(Message): def init_from_payload(self): self.acp_seid = self.payload[0] >> 2 - self.delay = (self.payload[1] << 8) | (self.payload[2]) + self.delay = (self.payload[1] << 8) | (self.payload[2]) def __str__(self): - return self.to_string([ - f'ACP_SEID: {self.acp_seid}', - f'delay: {self.delay}' - ]) + return self.to_string([f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}']) # ----------------------------------------------------------------------------- @@ -1122,16 +1162,16 @@ class DelayReport_Reject(Simple_Reject): # ----------------------------------------------------------------------------- class Protocol: - SINGLE_PACKET = 0 - START_PACKET = 1 + SINGLE_PACKET = 0 + START_PACKET = 1 CONTINUE_PACKET = 2 - END_PACKET = 3 + END_PACKET = 3 PACKET_TYPE_NAMES = { - SINGLE_PACKET: 'SINGLE_PACKET', - START_PACKET: 'START_PACKET', + SINGLE_PACKET: 'SINGLE_PACKET', + START_PACKET: 'START_PACKET', CONTINUE_PACKET: 'CONTINUE_PACKET', - END_PACKET: 'END_PACKET' + END_PACKET: 'END_PACKET', } @staticmethod @@ -1148,18 +1188,18 @@ class Protocol: return protocol def __init__(self, l2cap_channel, version=(1, 3)): - self.l2cap_channel = l2cap_channel - self.version = version - self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER - self.message_assembler = MessageAssembler(self.on_message) - self.transaction_results = [None] * 16 # Futures for up to 16 transactions + self.l2cap_channel = l2cap_channel + self.version = version + self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER + self.message_assembler = MessageAssembler(self.on_message) + self.transaction_results = [None] * 16 # Futures for up to 16 transactions self.transaction_semaphore = asyncio.Semaphore(16) - self.transaction_count = 0 - self.channel_acceptor = None - self.channel_connector = None - self.local_endpoints = [] # Local endpoints, with contiguous seid values - self.remote_endpoints = {} # Remote stream endpoints, by seid - self.streams = {} # Streams, by seid + self.transaction_count = 0 + self.channel_acceptor = None + self.channel_connector = None + self.local_endpoints = [] # Local endpoints, with contiguous seid values + self.remote_endpoints = {} # Remote stream endpoints, by seid + self.streams = {} # Streams, by seid # Register to receive PDUs from the channel l2cap_channel.sink = self.on_pdu @@ -1205,7 +1245,9 @@ class Protocol: response = await self.send_command(Discover_Command()) for endpoint_entry in response.endpoints: - logger.debug(f'getting endpoint capabilities for endpoint {endpoint_entry.seid}') + logger.debug( + f'getting endpoint capabilities for endpoint {endpoint_entry.seid}' + ) get_capabilities_response = await self.get_capabilities(endpoint_entry.seid) endpoint = DiscoveredStreamEndPoint( self, @@ -1213,7 +1255,7 @@ class Protocol: endpoint_entry.media_type, endpoint_entry.tsep, endpoint_entry.in_use, - get_capabilities_response.capabilities + get_capabilities_response.capabilities, ) self.remote_endpoints[endpoint_entry.seid] = endpoint @@ -1221,14 +1263,27 @@ class Protocol: def find_remote_sink_by_codec(self, media_type, codec_type): for endpoint in self.remote_endpoints.values(): - if not endpoint.in_use and endpoint.media_type == media_type and endpoint.tsep == AVDTP_TSEP_SNK: + if ( + not endpoint.in_use + and endpoint.media_type == media_type + and endpoint.tsep == AVDTP_TSEP_SNK + ): has_media_transport = False has_codec = False for capabilities in endpoint.capabilities: - if capabilities.service_category == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY: + if ( + capabilities.service_category + == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY + ): has_media_transport = True - elif capabilities.service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY: - if capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE and capabilities.media_codec_type == codec_type: + elif ( + capabilities.service_category + == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY + ): + if ( + capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE + and capabilities.media_codec_type == codec_type + ): has_codec = True if has_media_transport and has_codec: return endpoint @@ -1237,7 +1292,9 @@ class Protocol: self.message_assembler.on_pdu(pdu) def on_message(self, transaction_label, message): - logger.debug(f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}') + logger.debug( + f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}' + ) # Check that the identifier is not reserved if message.signal_identifier == 0: @@ -1245,7 +1302,10 @@ class Protocol: return # Check that the identifier is valid - if message.signal_identifier < 0 or message.signal_identifier > AVDTP_DELAYREPORT: + if ( + message.signal_identifier < 0 + or message.signal_identifier > AVDTP_DELAYREPORT + ): logger.warning('!!! invalid signal identifier') self.send_message(transaction_label, General_Reject()) @@ -1258,7 +1318,9 @@ class Protocol: response = handler(message) self.send_message(transaction_label, response) except Exception as error: - logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') + logger.warning( + f'{color("!!! Exception in handler:", "red")} {error}' + ) else: logger.warning('unhandled command') else: @@ -1281,8 +1343,12 @@ class Protocol: logger.debug(color('<<< L2CAP channel open', 'magenta')) def send_message(self, transaction_label, message): - logger.debug(f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}') - max_fragment_size = self.l2cap_channel.mtu - 3 # Enough space for a 3-byte start packet header + logger.debug( + f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}' + ) + max_fragment_size = ( + self.l2cap_channel.mtu - 3 + ) # Enough space for a 3-byte start packet header payload = message.payload if len(payload) + 2 <= self.l2cap_channel.mtu: # Fits in a single packet @@ -1292,13 +1358,19 @@ class Protocol: done = False while not done: - first_header_byte = transaction_label << 4 | packet_type << 2 | message.message_type + first_header_byte = ( + transaction_label << 4 | packet_type << 2 | message.message_type + ) if packet_type == self.SINGLE_PACKET: header = bytes([first_header_byte, message.signal_identifier]) elif packet_type == self.START_PACKET: - packet_count = (max_fragment_size - 1 + len(payload)) // max_fragment_size - header = bytes([first_header_byte, message.signal_identifier, packet_count]) + packet_count = ( + max_fragment_size - 1 + len(payload) + ) // max_fragment_size + header = bytes( + [first_header_byte, message.signal_identifier, packet_count] + ) else: header = bytes([first_header_byte]) @@ -1308,7 +1380,11 @@ class Protocol: # Prepare for the next packet payload = payload[max_fragment_size:] if payload: - packet_type = self.CONTINUE_PACKET if payload > max_fragment_size else self.END_PACKET + packet_type = ( + self.CONTINUE_PACKET + if payload > max_fragment_size + else self.END_PACKET + ) else: done = True @@ -1322,7 +1398,10 @@ class Protocol: response = await transaction_result # Check for errors - if response.message_type == Message.GENERAL_REJECT or response.message_type == Message.RESPONSE_REJECT: + if ( + response.message_type == Message.GENERAL_REJECT + or response.message_type == Message.RESPONSE_REJECT + ): raise ProtocolError(response.error_code, 'avdtp') return response @@ -1340,7 +1419,7 @@ class Protocol: self.transaction_count += 1 return (transaction_label, transaction_result) - assert(False) # Should never reach this + assert False # Should never reach this async def get_capabilities(self, seid): if self.version > (1, 2): @@ -1349,7 +1428,9 @@ class Protocol: return await self.send_command(Get_Capabilities_Command(seid)) async def set_configuration(self, acp_seid, int_seid, capabilities): - return await self.send_command(Set_Configuration_Command(acp_seid, int_seid, capabilities)) + return await self.send_command( + Set_Configuration_Command(acp_seid, int_seid, capabilities) + ) async def get_configuration(self, seid): response = await self.send_command(Get_Configuration_Command(seid)) @@ -1537,6 +1618,7 @@ class Listener(EventEmitter): server = Protocol(channel, self.version) self.set_server(channel.connection, server) self.emit('connection', server) + channel.on('open', on_channel_open) @@ -1562,8 +1644,7 @@ class Stream: raise InvalidStateError('current state is not IDLE') await self.remote_endpoint.set_configuration( - self.local_endpoint.seid, - self.local_endpoint.configuration + self.local_endpoint.seid, self.local_endpoint.configuration ) self.change_state(AVDTP_CONFIGURED_STATE) @@ -1639,7 +1720,11 @@ class Stream: self.change_state(AVDTP_CONFIGURED_STATE) def on_get_configuration_command(self, configuration): - if self.state not in {AVDTP_CONFIGURED_STATE, AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}: + if self.state not in { + AVDTP_CONFIGURED_STATE, + AVDTP_OPEN_STATE, + AVDTP_STREAMING_STATE, + }: return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR) return self.local_endpoint.on_get_configuration_command(configuration) @@ -1718,7 +1803,7 @@ class Stream: def on_l2cap_connection(self, channel): logger.debug(color('<<< stream channel connected', 'magenta')) self.rtp_channel = channel - channel.on('open', self.on_l2cap_channel_open) + channel.on('open', self.on_l2cap_channel_open) channel.on('close', self.on_l2cap_channel_close) # We don't need more channels @@ -1744,11 +1829,11 @@ class Stream: remote_endpoint must be a subclass of StreamEndPointProxy ''' - self.protocol = protocol - self.local_endpoint = local_endpoint + self.protocol = protocol + self.local_endpoint = local_endpoint self.remote_endpoint = remote_endpoint - self.rtp_channel = None - self.state = AVDTP_IDLE_STATE + self.rtp_channel = None + self.state = AVDTP_IDLE_STATE local_endpoint.stream = self local_endpoint.in_use = 1 @@ -1760,38 +1845,36 @@ class Stream: # ----------------------------------------------------------------------------- class StreamEndPoint: def __init__(self, seid, media_type, tsep, in_use, capabilities): - self.seid = seid - self.media_type = media_type - self.tsep = tsep - self.in_use = in_use + self.seid = seid + self.media_type = media_type + self.tsep = tsep + self.in_use = in_use self.capabilities = capabilities def __str__(self): - return '\n'.join([ - 'SEP(', - f' seid={self.seid}', - f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}', - f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}', - f' in_use={self.in_use}', - ' capabilities=[', - '\n'.join([f' {x}' for x in self.capabilities]), - ' ]', - ')' - ]) + return '\n'.join( + [ + 'SEP(', + f' seid={self.seid}', + f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}', + f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}', + f' in_use={self.in_use}', + ' capabilities=[', + '\n'.join([f' {x}' for x in self.capabilities]), + ' ]', + ')', + ] + ) # ----------------------------------------------------------------------------- class StreamEndPointProxy: def __init__(self, protocol, seid): - self.seid = seid + self.seid = seid self.protocol = protocol async def set_configuration(self, int_seid, configuration): - return await self.protocol.set_configuration( - self.seid, - int_seid, - configuration - ) + return await self.protocol.set_configuration(self.seid, int_seid, configuration) async def open(self): return await self.protocol.open(self.seid) @@ -1818,11 +1901,13 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy): # ----------------------------------------------------------------------------- class LocalStreamEndPoint(StreamEndPoint): - def __init__(self, protocol, seid, media_type, tsep, capabilities, configuration=[]): + def __init__( + self, protocol, seid, media_type, tsep, capabilities, configuration=[] + ): super().__init__(seid, media_type, tsep, 0, capabilities) - self.protocol = protocol + self.protocol = protocol self.configuration = configuration - self.stream = None + self.stream = None async def start(self): pass @@ -1866,9 +1951,17 @@ class LocalSource(LocalStreamEndPoint, EventEmitter): def __init__(self, protocol, seid, codec_capabilities, packet_pump): capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), - codec_capabilities + codec_capabilities, ] - LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SRC, capabilities, capabilities) + LocalStreamEndPoint.__init__( + self, + protocol, + seid, + codec_capabilities.media_type, + AVDTP_TSEP_SRC, + capabilities, + capabilities, + ) EventEmitter.__init__(self) self.packet_pump = packet_pump @@ -1901,9 +1994,16 @@ class LocalSink(LocalStreamEndPoint, EventEmitter): def __init__(self, protocol, seid, codec_capabilities): capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), - codec_capabilities + codec_capabilities, ] - LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SNK, capabilities) + LocalStreamEndPoint.__init__( + self, + protocol, + seid, + codec_capabilities.media_type, + AVDTP_TSEP_SNK, + capabilities, + ) EventEmitter.__init__(self) def on_set_configuration_command(self, configuration): @@ -1917,5 +2017,7 @@ class LocalSink(LocalStreamEndPoint, EventEmitter): def on_avdtp_packet(self, packet): rtp_packet = MediaPacket.from_bytes(packet) - logger.debug(f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}') + logger.debug( + f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}' + ) self.emit('rtp_packet', rtp_packet) diff --git a/bumble/bridge.py b/bumble/bridge.py index 2b4cd94f..ac3ba8a5 100644 --- a/bumble/bridge.py +++ b/bumble/bridge.py @@ -30,10 +30,10 @@ logger = logging.getLogger(__name__) class HCI_Bridge: class Forwarder: def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace): - self.hci_sink = hci_sink + self.hci_sink = hci_sink self.sender_hci_sink = sender_hci_sink - self.packet_filter = packet_filter - self.trace = trace + self.packet_filter = packet_filter + self.trace = trace def on_packet(self, packet): # Convert the packet bytes to an object @@ -61,15 +61,15 @@ class HCI_Bridge: hci_host_sink, hci_controller_source, hci_controller_sink, - host_to_controller_filter = None, - controller_to_host_filter = None + host_to_controller_filter=None, + controller_to_host_filter=None, ): tracer = PacketTracer(emit_message=logger.info) host_to_controller_forwarder = HCI_Bridge.Forwarder( hci_controller_sink, hci_host_sink, host_to_controller_filter, - lambda packet: tracer.trace(packet, 0) + lambda packet: tracer.trace(packet, 0), ) hci_host_source.set_packet_sink(host_to_controller_forwarder) @@ -77,6 +77,6 @@ class HCI_Bridge: hci_host_sink, hci_controller_sink, controller_to_host_filter, - lambda packet: tracer.trace(packet, 1) + lambda packet: tracer.trace(packet, 1), ) hci_controller_source.set_packet_sink(controller_to_host_forwarder) diff --git a/bumble/company_ids.py b/bumble/company_ids.py index c9c9d1aa..d571e129 100644 --- a/bumble/company_ids.py +++ b/bumble/company_ids.py @@ -2704,5 +2704,5 @@ COMPANY_IDENTIFIERS = { 0x0A7C: "WAFERLOCK", 0x0A7D: "Freedman Electronics Pty Ltd", 0x0A7E: "Keba AG", - 0x0A7F: "Intuity Medical" -} \ No newline at end of file + 0x0A7F: "Intuity Medical", +} diff --git a/bumble/controller.py b/bumble/controller.py index 7982e9bd..41e6b165 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -39,67 +39,79 @@ class DataObject: # ----------------------------------------------------------------------------- class Connection: def __init__(self, controller, handle, role, peer_address, link): - self.controller = controller - self.handle = handle - self.role = role + self.controller = controller + self.handle = handle + self.role = role self.peer_address = peer_address - self.link = link - self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) + self.link = link + self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) def on_hci_acl_data_packet(self, packet): self.assembler.feed_packet(packet) - self.controller.send_hci_packet(HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])) + self.controller.send_hci_packet( + HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)]) + ) def on_acl_pdu(self, data): if self.link: - self.link.send_acl_data(self.controller.random_address, self.peer_address, data) + self.link.send_acl_data( + self.controller.random_address, self.peer_address, data + ) # ----------------------------------------------------------------------------- class Controller: - def __init__(self, name, host_source = None, host_sink = None, link = None): - self.name = name + def __init__(self, name, host_source=None, host_sink=None, link=None): + self.name = name self.hci_sink = None - self.link = link + self.link = link - self.central_connections = {} # Connections where this controller is the central - self.peripheral_connections = {} # Connections where this controller is the peripheral + self.central_connections = ( + {} + ) # Connections where this controller is the central + self.peripheral_connections = ( + {} + ) # Connections where this controller is the peripheral - self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 - self.hci_revision = 0 - self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0 - self.lmp_subversion = 0 - self.lmp_features = bytes.fromhex('0000000060000000') # BR/EDR Not Supported, LE Supported (Controller) - self.manufacturer_name = 0xFFFF - self.hc_le_data_packet_length = 27 + self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 + self.hci_revision = 0 + self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0 + self.lmp_subversion = 0 + self.lmp_features = bytes.fromhex( + '0000000060000000' + ) # BR/EDR Not Supported, LE Supported (Controller) + self.manufacturer_name = 0xFFFF + self.hc_le_data_packet_length = 27 self.hc_total_num_le_data_packets = 64 - self.supported_commands = bytes.fromhex('2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000') - self.le_features = bytes.fromhex('ff49010000000000') - self.le_states = bytes.fromhex('ffff3fffff030000') + self.supported_commands = bytes.fromhex( + '2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000' + ) + self.le_features = bytes.fromhex('ff49010000000000') + self.le_states = bytes.fromhex('ffff3fffff030000') self.advertising_channel_tx_power = 0 - self.filter_accept_list_size = 8 - self.resolving_list_size = 8 - self.supported_max_tx_octets = 27 - self.supported_max_tx_time = 10000 # microseconds - self.supported_max_rx_octets = 27 - self.supported_max_rx_time = 10000 # microseconds - self.suggested_max_tx_octets = 27 - self.suggested_max_tx_time = 0x0148 # microseconds - self.default_phy = bytes([0, 0, 0]) - self.le_scan_type = 0 - self.le_scan_interval = 0x10 - self.le_scan_window = 0x10 - self.le_scan_enable = 0 - self.le_scan_own_address_type = Address.RANDOM_DEVICE_ADDRESS - self.le_scanning_filter_policy = 0 - self.le_scan_response_data = None - self.le_address_resolution = False - self.le_rpa_timeout = 0 - self.sync_flow_control = False - self.local_name = 'Bumble' + self.filter_accept_list_size = 8 + self.resolving_list_size = 8 + self.supported_max_tx_octets = 27 + self.supported_max_tx_time = 10000 # microseconds + self.supported_max_rx_octets = 27 + self.supported_max_rx_time = 10000 # microseconds + self.suggested_max_tx_octets = 27 + self.suggested_max_tx_time = 0x0148 # microseconds + self.default_phy = bytes([0, 0, 0]) + self.le_scan_type = 0 + self.le_scan_interval = 0x10 + self.le_scan_window = 0x10 + self.le_scan_enable = 0 + self.le_scan_own_address_type = Address.RANDOM_DEVICE_ADDRESS + self.le_scanning_filter_policy = 0 + self.le_scan_response_data = None + self.le_address_resolution = False + self.le_rpa_timeout = 0 + self.sync_flow_control = False + self.local_name = 'Bumble' - self.advertising_interval = 2000 # Fixed for now - self.advertising_data = None + self.advertising_interval = 2000 # Fixed for now + self.advertising_data = None self.advertising_timer_handle = None self._random_address = Address('00:00:00:00:00:00') @@ -162,7 +174,9 @@ class Controller: self.on_hci_packet(HCI_Packet.from_bytes(packet)) def on_hci_packet(self, packet): - logger.debug(f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}') + logger.debug( + f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}' + ) # If the packet is a command, invoke the handler for this packet if packet.hci_packet_type == HCI_COMMAND_PACKET: @@ -179,11 +193,13 @@ class Controller: handler = getattr(self, handler_name, self.on_hci_command) result = handler(command) if type(result) is bytes: - self.send_hci_packet(HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = command.op_code, - return_parameters = result - )) + self.send_hci_packet( + HCI_Command_Complete_Event( + num_hci_command_packets=1, + command_opcode=command.op_code, + return_parameters=result, + ) + ) def on_hci_event_packet(self, event): logger.warning('!!! unexpected event packet') @@ -192,14 +208,18 @@ class Controller: # Look for the connection to which this data belongs connection = self.find_connection_by_handle(packet.connection_handle) if connection is None: - logger.warning(f'!!! no connection for handle 0x{packet.connection_handle:04X}') + logger.warning( + f'!!! no connection for handle 0x{packet.connection_handle:04X}' + ) return # Pass the packet to the connection connection.on_hci_acl_data_packet(packet) def send_hci_packet(self, packet): - logger.debug(f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}') + logger.debug( + f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}' + ) if self.host: self.host.on_packet(packet.to_bytes()) @@ -215,8 +235,7 @@ class Controller: handle = 0 max_handle = 0 for connection in itertools.chain( - self.central_connections.values(), - self.peripheral_connections.values() + self.central_connections.values(), self.peripheral_connections.values() ): max_handle = max(max_handle, connection.handle) if connection.handle == handle: @@ -225,12 +244,13 @@ class Controller: return handle def find_connection_by_address(self, address): - return self.central_connections.get(address) or self.peripheral_connections.get(address) + return self.central_connections.get(address) or self.peripheral_connections.get( + address + ) def find_connection_by_handle(self, handle): for connection in itertools.chain( - self.central_connections.values(), - self.peripheral_connections.values() + self.central_connections.values(), self.peripheral_connections.values() ): if connection.handle == handle: return connection @@ -253,22 +273,26 @@ class Controller: connection = self.peripheral_connections.get(peer_address) if connection is None: connection_handle = self.allocate_connection_handle() - connection = Connection(self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link) + connection = Connection( + self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link + ) self.peripheral_connections[peer_address] = connection logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') # Then say that the connection has completed - self.send_hci_packet(HCI_LE_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = connection.handle, - role = connection.role, - peer_address_type = peer_address_type, - peer_address = peer_address, - connection_interval = 10, # FIXME - peripheral_latency = 0, # FIXME - supervision_timeout = 10, # FIXME - central_clock_accuracy = 7 # FIXME - )) + self.send_hci_packet( + HCI_LE_Connection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=connection.handle, + role=connection.role, + peer_address_type=peer_address_type, + peer_address=peer_address, + connection_interval=10, # FIXME + peripheral_latency=0, # FIXME + supervision_timeout=10, # FIXME + central_clock_accuracy=7, # FIXME + ) + ) def on_link_central_disconnected(self, peer_address, reason): ''' @@ -277,18 +301,22 @@ class Controller: # Send a disconnection complete event if connection := self.peripheral_connections.get(peer_address): - self.send_hci_packet(HCI_Disconnection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = connection.handle, - reason = reason - )) + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=connection.handle, + reason=reason, + ) + ) # Remove the connection del self.peripheral_connections[peer_address] else: logger.warn(f'!!! No peripheral connection found for {peer_address}') - def on_link_peripheral_connection_complete(self, le_create_connection_command, status): + def on_link_peripheral_connection_complete( + self, le_create_connection_command, status + ): ''' Called by the link when a connection has been made or has failed to be made ''' @@ -300,29 +328,29 @@ class Controller: if connection is None: connection_handle = self.allocate_connection_handle() connection = Connection( - self, - connection_handle, - BT_CENTRAL_ROLE, - peer_address, - self.link + self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link ) self.central_connections[peer_address] = connection - logger.debug(f'New CENTRAL connection handle: 0x{connection_handle:04X}') + logger.debug( + f'New CENTRAL connection handle: 0x{connection_handle:04X}' + ) else: connection = None # Say that the connection has completed - self.send_hci_packet(HCI_LE_Connection_Complete_Event( - status = status, - connection_handle = connection.handle if connection else 0, - role = BT_CENTRAL_ROLE, - peer_address_type = le_create_connection_command.peer_address_type, - peer_address = le_create_connection_command.peer_address, - connection_interval = le_create_connection_command.connection_interval_min, - peripheral_latency = le_create_connection_command.max_latency, - supervision_timeout = le_create_connection_command.supervision_timeout, - central_clock_accuracy = 0 - )) + self.send_hci_packet( + HCI_LE_Connection_Complete_Event( + status=status, + connection_handle=connection.handle if connection else 0, + role=BT_CENTRAL_ROLE, + peer_address_type=le_create_connection_command.peer_address_type, + peer_address=le_create_connection_command.peer_address, + connection_interval=le_create_connection_command.connection_interval_min, + peripheral_latency=le_create_connection_command.max_latency, + supervision_timeout=le_create_connection_command.supervision_timeout, + central_clock_accuracy=0, + ) + ) def on_link_peripheral_disconnection_complete(self, disconnection_command, status): ''' @@ -330,14 +358,18 @@ class Controller: ''' # Send a disconnection complete event - self.send_hci_packet(HCI_Disconnection_Complete_Event( - status = status, - connection_handle = disconnection_command.connection_handle, - reason = disconnection_command.reason - )) + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=status, + connection_handle=disconnection_command.connection_handle, + reason=disconnection_command.reason, + ) + ) # Remove the connection - if connection := self.find_central_connection_by_handle(disconnection_command.connection_handle): + if connection := self.find_central_connection_by_handle( + disconnection_command.connection_handle + ): logger.debug(f'CENTRAL Connection removed: {connection}') del self.central_connections[connection.peer_address] @@ -348,11 +380,13 @@ class Controller: # Send a disconnection complete event if connection := self.central_connections.get(peer_address): - self.send_hci_packet(HCI_Disconnection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = connection.handle, - reason = HCI_CONNECTION_TIMEOUT_ERROR - )) + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=connection.handle, + reason=HCI_CONNECTION_TIMEOUT_ERROR, + ) + ) # Remove the connection del self.central_connections[peer_address] @@ -364,9 +398,7 @@ class Controller: if connection := self.find_connection_by_address(peer_address): self.send_hci_packet( HCI_Encryption_Change_Event( - status = 0, - connection_handle = connection.handle, - encryption_enabled = 1 + status=0, connection_handle=connection.handle, encryption_enabled=1 ) ) @@ -390,22 +422,22 @@ class Controller: # Send a scan report report = HCI_Object( HCI_LE_Advertising_Report_Event.REPORT_FIELDS, - event_type = HCI_LE_Advertising_Report_Event.ADV_IND, - address_type = sender_address.address_type, - address = sender_address, - data = data, - rssi = -50 + event_type=HCI_LE_Advertising_Report_Event.ADV_IND, + address_type=sender_address.address_type, + address=sender_address, + data=data, + rssi=-50, ) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) # Simulate a scan response report = HCI_Object( HCI_LE_Advertising_Report_Event.REPORT_FIELDS, - event_type = HCI_LE_Advertising_Report_Event.SCAN_RSP, - address_type = sender_address.address_type, - address = sender_address, - data = data, - rssi = -50 + event_type=HCI_LE_Advertising_Report_Event.SCAN_RSP, + address_type=sender_address.address_type, + address=sender_address, + data=data, + rssi=-50, ) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) @@ -414,14 +446,18 @@ class Controller: ############################################################ def on_advertising_timer_fired(self): self.send_advertising_data() - self.advertising_timer_handle = asyncio.get_running_loop().call_later(self.advertising_interval / 1000.0, self.on_advertising_timer_fired) + self.advertising_timer_handle = asyncio.get_running_loop().call_later( + self.advertising_interval / 1000.0, self.on_advertising_timer_fired + ) def start_advertising(self): # Stop any ongoing advertising before we start again self.stop_advertising() # Advertise now - self.advertising_timer_handle = asyncio.get_running_loop().call_soon(self.on_advertising_timer_fired) + self.advertising_timer_handle = asyncio.get_running_loop().call_soon( + self.on_advertising_timer_fired + ) def stop_advertising(self): if self.advertising_timer_handle is not None: @@ -455,14 +491,20 @@ class Controller: See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command ''' # First, say that the disconnection is pending - self.send_hci_packet(HCI_Command_Status_Event( - status = HCI_COMMAND_STATUS_PENDING, - num_hci_command_packets = 1, - command_opcode = command.op_code - )) + self.send_hci_packet( + HCI_Command_Status_Event( + status=HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) # Notify the link of the disconnection - if not (connection := self.find_central_connection_by_handle(command.connection_handle)): + if not ( + connection := self.find_central_connection_by_handle( + command.connection_handle + ) + ): logger.warn('connection not found') return @@ -590,7 +632,7 @@ class Controller: self.hci_revision, self.lmp_version, self.manufacturer_name, - self.lmp_subversion + self.lmp_subversion, ) def on_hci_read_local_supported_commands_command(self, command): @@ -609,7 +651,11 @@ class Controller: ''' See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command ''' - bd_addr = self._public_address.to_bytes() if self._public_address is not None else bytes(6) + bd_addr = ( + self._public_address.to_bytes() + if self._public_address is not None + else bytes(6) + ) return bytes([HCI_SUCCESS]) + bd_addr def on_hci_le_set_event_mask_command(self, command): @@ -623,10 +669,12 @@ class Controller: ''' See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command ''' - return struct.pack('> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F)) + return ( + (class_of_device >> 13 & 0x7FF), + (class_of_device >> 8 & 0x1F), + (class_of_device >> 2 & 0x3F), + ) @staticmethod def pack_class_of_device(service_classes, major_device_class, minor_device_class): @@ -542,7 +579,9 @@ class DeviceClass: @staticmethod def service_class_labels(service_class_flags): - return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS) + return bit_flags_to_strings( + service_class_flags, DeviceClass.SERVICE_CLASS_LABELS + ) @staticmethod def major_device_class_name(device_class): @@ -560,6 +599,8 @@ class DeviceClass: # Advertising Data # ----------------------------------------------------------------------------- class AdvertisingData: + # fmt: off + # This list is only partial, it still needs to be filled in from the spec FLAGS = 0x01 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 @@ -671,7 +712,9 @@ class AdvertisingData: BR_EDR_CONTROLLER_FLAG = 0x08 BR_EDR_HOST_FLAG = 0x10 - def __init__(self, ad_structures = []): + # fmt: on + + def __init__(self, ad_structures=[]): self.ad_structures = ad_structures[:] @staticmethod @@ -682,19 +725,17 @@ class AdvertisingData: @staticmethod def flags_to_string(flags, short=False): - flag_names = [ - 'LE Limited', - 'LE General', - 'No BR/EDR', - 'BR/EDR C', - 'BR/EDR H' - ] if short else [ - 'LE Limited Discoverable Mode', - 'LE General Discoverable Mode', - 'BR/EDR Not Supported', - 'Simultaneous LE and BR/EDR (Controller)', - 'Simultaneous LE and BR/EDR (Host)' - ] + flag_names = ( + ['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H'] + if short + else [ + 'LE Limited Discoverable Mode', + 'LE General Discoverable Mode', + 'BR/EDR Not Supported', + 'Simultaneous LE and BR/EDR (Controller)', + 'Simultaneous LE and BR/EDR (Host)', + ] + ) return ','.join(bit_flags_to_strings(flags, flag_names)) @staticmethod @@ -702,16 +743,18 @@ class AdvertisingData: uuids = [] offset = 0 while (uuid_size * (offset + 1)) <= len(ad_data): - uuids.append(UUID.from_bytes(ad_data[offset:offset + uuid_size])) + uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size])) offset += uuid_size return uuids @staticmethod def uuid_list_to_string(ad_data, uuid_size): - return ', '.join([ - str(uuid) - for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size) - ]) + return ', '.join( + [ + str(uuid) + for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size) + ] + ) @staticmethod def ad_data_to_string(ad_type, ad_data): @@ -776,19 +819,19 @@ class AdvertisingData: if ad_type in { AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, - AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS + AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS, }: return AdvertisingData.uuid_list_to_objects(ad_data, 2) elif ad_type in { AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, - AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS + AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS, }: return AdvertisingData.uuid_list_to_objects(ad_data, 4) elif ad_type in { AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, - AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS + AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS, }: return AdvertisingData.uuid_list_to_objects(ad_data, 16) elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: @@ -800,17 +843,14 @@ class AdvertisingData: elif ad_type in { AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.COMPLETE_LOCAL_NAME, - AdvertisingData.URI + AdvertisingData.URI, }: return ad_data.decode("utf-8") - elif ad_type in { - AdvertisingData.TX_POWER_LEVEL, - AdvertisingData.FLAGS - }: + elif ad_type in {AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS}: return ad_data[0] elif ad_type in { AdvertisingData.APPEARANCE, - AdvertisingData.ADVERTISING_INTERVAL + AdvertisingData.ADVERTISING_INTERVAL, }: return struct.unpack(' 0: ad_type = data[offset] - ad_data = data[offset + 1:offset + length] + ad_data = data[offset + 1 : offset + length] self.ad_structures.append((ad_type, ad_data)) offset += length @@ -840,19 +880,33 @@ class AdvertisingData: If return_all is True, returns a (possibly empty) list of matches, else returns the first entry, or None if no structure matches. ''' + def process_ad_data(ad_data): return ad_data if raw else self.ad_data_to_object(type_id, ad_data) if return_all: - return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] + return [ + process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id + ] else: - return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None) + return next( + ( + process_ad_data(ad[1]) + for ad in self.ad_structures + if ad[0] == type_id + ), + None, + ) def __bytes__(self): - return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]) + return b''.join( + [bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures] + ) def to_string(self, separator=', '): - return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]) + return separator.join( + [AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures] + ) def __str__(self): return self.to_string() @@ -864,7 +918,7 @@ class AdvertisingData: class ConnectionParameters: def __init__(self, connection_interval, peripheral_latency, supervision_timeout): self.connection_interval = connection_interval - self.peripheral_latency = peripheral_latency + self.peripheral_latency = peripheral_latency self.supervision_timeout = supervision_timeout def __str__(self): diff --git a/bumble/crypto.py b/bumble/crypto.py index 4f134765..9cf8e3d6 100644 --- a/bumble/crypto.py +++ b/bumble/crypto.py @@ -24,19 +24,16 @@ import logging import operator import platform + if platform.system() != 'Emscripten': import secrets - from cryptography.hazmat.primitives.ciphers import ( - Cipher, - algorithms, - modes - ) + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.asymmetric.ec import ( generate_private_key, ECDH, EllipticCurvePublicNumbers, EllipticCurvePrivateNumbers, - SECP256R1 + SECP256R1, ) from cryptography.hazmat.primitives import cmac else: @@ -66,16 +63,26 @@ class EccKey: d = int.from_bytes(d_bytes, byteorder='big', signed=False) x = int.from_bytes(x_bytes, byteorder='big', signed=False) y = int.from_bytes(y_bytes, byteorder='big', signed=False) - private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key() + private_key = EllipticCurvePrivateNumbers( + d, EllipticCurvePublicNumbers(x, y, SECP256R1()) + ).private_key() return cls(private_key) @property def x(self): - return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big') + return ( + self.private_key.public_key() + .public_numbers() + .x.to_bytes(32, byteorder='big') + ) @property def y(self): - return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big') + return ( + self.private_key.public_key() + .public_numbers() + .y.to_bytes(32, byteorder='big') + ) def dh(self, public_key_x, public_key_y): x = int.from_bytes(public_key_x, byteorder='big', signed=False) @@ -92,7 +99,7 @@ class EccKey: # ----------------------------------------------------------------------------- def xor(x, y): - assert(len(x) == len(y)) + assert len(x) == len(y) return bytes(map(operator.xor, x, y)) @@ -165,7 +172,11 @@ def f4(u, v, x, z): ''' See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4 ''' - return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x))))) + return bytes( + reversed( + aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x))) + ) + ) # ----------------------------------------------------------------------------- @@ -177,28 +188,36 @@ def f5(w, n1, n2, a1, a2): ''' salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE') t = aes_cmac(bytes(reversed(w)), salt) - key_id = bytes([0x62, 0x74, 0x6c, 0x65]) + key_id = bytes([0x62, 0x74, 0x6C, 0x65]) return ( - bytes(reversed(aes_cmac( - bytes([0]) + - key_id + - bytes(reversed(n1)) + - bytes(reversed(n2)) + - bytes(reversed(a1)) + - bytes(reversed(a2)) + - bytes([1, 0]), - t - ))), - bytes(reversed(aes_cmac( - bytes([1]) + - key_id + - bytes(reversed(n1)) + - bytes(reversed(n2)) + - bytes(reversed(a1)) + - bytes(reversed(a2)) + - bytes([1, 0]), - t - ))) + bytes( + reversed( + aes_cmac( + bytes([0]) + + key_id + + bytes(reversed(n1)) + + bytes(reversed(n2)) + + bytes(reversed(a1)) + + bytes(reversed(a2)) + + bytes([1, 0]), + t, + ) + ) + ), + bytes( + reversed( + aes_cmac( + bytes([1]) + + key_id + + bytes(reversed(n1)) + + bytes(reversed(n2)) + + bytes(reversed(a1)) + + bytes(reversed(a2)) + + bytes([1, 0]), + t, + ) + ) + ), ) @@ -207,15 +226,19 @@ def f6(w, n1, n2, r, io_cap, a1, a2): ''' See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6 ''' - return bytes(reversed(aes_cmac( - bytes(reversed(n1)) + - bytes(reversed(n2)) + - bytes(reversed(r)) + - bytes(reversed(io_cap)) + - bytes(reversed(a1)) + - bytes(reversed(a2)), - bytes(reversed(w)) - ))) + return bytes( + reversed( + aes_cmac( + bytes(reversed(n1)) + + bytes(reversed(n2)) + + bytes(reversed(r)) + + bytes(reversed(io_cap)) + + bytes(reversed(a1)) + + bytes(reversed(a2)), + bytes(reversed(w)), + ) + ) + ) # ----------------------------------------------------------------------------- @@ -224,10 +247,14 @@ def g2(u, v, x, y): See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2 ''' return int.from_bytes( - aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:], - byteorder='big' + aes_cmac( + bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), + bytes(reversed(x)), + )[-4:], + byteorder='big', ) + # ----------------------------------------------------------------------------- def h6(w, key_id): ''' @@ -235,6 +262,7 @@ def h6(w, key_id): ''' return aes_cmac(key_id, w) + # ----------------------------------------------------------------------------- def h7(salt, w): ''' diff --git a/bumble/device.py b/bumble/device.py index ea90cac9..5a16392b 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -27,7 +27,12 @@ from .host import Host from .gatt import * from .gap import GenericAccessService from .core import AdvertisingData, BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE -from .utils import AsyncRunner, CompositeEventEmitter, setup_event_forwarding, composite_listener +from .utils import ( + AsyncRunner, + CompositeEventEmitter, + setup_event_forwarding, + composite_listener, +) from . import gatt_client from . import gatt_server from . import smp @@ -43,6 +48,8 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + DEVICE_MIN_SCAN_INTERVAL = 25 DEVICE_MAX_SCAN_INTERVAL = 10240 DEVICE_MIN_SCAN_WINDOW = 25 @@ -73,6 +80,8 @@ DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONN DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS +# fmt: on + # ----------------------------------------------------------------------------- # Classes @@ -80,8 +89,10 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN # ----------------------------------------------------------------------------- class Advertisement: - TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE - RSSI_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE + TX_POWER_NOT_AVAILABLE = ( + HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE + ) + RSSI_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE @classmethod def from_advertising_report(cls, report): @@ -93,36 +104,36 @@ class Advertisement: def __init__( self, address, - rssi = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE, - is_legacy = False, - is_anonymous = False, - is_connectable = False, - is_directed = False, - is_scannable = False, - is_scan_response = False, - is_complete = True, - is_truncated = False, - primary_phy = 0, - secondary_phy = 0, - tx_power = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE, - sid = 0, - data = b'' + rssi=HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE, + is_legacy=False, + is_anonymous=False, + is_connectable=False, + is_directed=False, + is_scannable=False, + is_scan_response=False, + is_complete=True, + is_truncated=False, + primary_phy=0, + secondary_phy=0, + tx_power=HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE, + sid=0, + data=b'', ): - self.address = address - self.rssi = rssi - self.is_legacy = is_legacy - self.is_anonymous = is_anonymous - self.is_connectable = is_connectable - self.is_directed = is_directed - self.is_scannable = is_scannable + self.address = address + self.rssi = rssi + self.is_legacy = is_legacy + self.is_anonymous = is_anonymous + self.is_connectable = is_connectable + self.is_directed = is_directed + self.is_scannable = is_scannable self.is_scan_response = is_scan_response - self.is_complete = is_complete - self.is_truncated = is_truncated - self.primary_phy = primary_phy - self.secondary_phy = secondary_phy - self.tx_power = tx_power - self.sid = sid - self.data = AdvertisingData.from_bytes(data) + self.is_complete = is_complete + self.is_truncated = is_truncated + self.primary_phy = primary_phy + self.secondary_phy = secondary_phy + self.tx_power = tx_power + self.sid = sid + self.data = AdvertisingData.from_bytes(data) # ----------------------------------------------------------------------------- @@ -130,20 +141,24 @@ class LegacyAdvertisement(Advertisement): @classmethod def from_advertising_report(cls, report): return cls( - address = report.address, - rssi = report.rssi, - is_legacy = True, - is_connectable = report.event_type in { + address=report.address, + rssi=report.rssi, + is_legacy=True, + is_connectable=report.event_type + in { HCI_LE_Advertising_Report_Event.ADV_IND, - HCI_LE_Advertising_Report_Event.ADV_DIRECT_IND + HCI_LE_Advertising_Report_Event.ADV_DIRECT_IND, }, - is_directed = report.event_type == HCI_LE_Advertising_Report_Event.ADV_DIRECT_IND, - is_scannable = report.event_type in { + is_directed=report.event_type + == HCI_LE_Advertising_Report_Event.ADV_DIRECT_IND, + is_scannable=report.event_type + in { HCI_LE_Advertising_Report_Event.ADV_IND, - HCI_LE_Advertising_Report_Event.ADV_SCAN_IND + HCI_LE_Advertising_Report_Event.ADV_SCAN_IND, }, - is_scan_response = report.event_type == HCI_LE_Advertising_Report_Event.SCAN_RSP, - data = report.data + is_scan_response=report.event_type + == HCI_LE_Advertising_Report_Event.SCAN_RSP, + data=report.data, ) @@ -151,6 +166,7 @@ class LegacyAdvertisement(Advertisement): class ExtendedAdvertisement(Advertisement): @classmethod def from_advertising_report(cls, report): + # fmt: off return cls( address = report.address, rssi = report.rssi, @@ -168,32 +184,39 @@ class ExtendedAdvertisement(Advertisement): sid = report.advertising_sid, data = report.data ) + # fmt: on # ----------------------------------------------------------------------------- class AdvertisementDataAccumulator: def __init__(self, passive=False): - self.passive = passive + self.passive = passive self.last_advertisement = None - self.last_data = b'' + self.last_data = b'' def update(self, report): advertisement = Advertisement.from_advertising_report(report) result = None if advertisement.is_scan_response: - if self.last_advertisement is not None and not self.last_advertisement.is_scan_response: + if ( + self.last_advertisement is not None + and not self.last_advertisement.is_scan_response + ): # This is the response to a scannable advertisement - result = Advertisement.from_advertising_report(report) + result = Advertisement.from_advertising_report(report) result.is_connectable = self.last_advertisement.is_connectable - result.is_scannable = True - result.data = AdvertisingData.from_bytes(self.last_data + report.data) + result.is_scannable = True + result.data = AdvertisingData.from_bytes(self.last_data + report.data) self.last_data = b'' else: if ( - self.passive or - (not advertisement.is_scannable) or - (self.last_advertisement is not None and not self.last_advertisement.is_scan_response) + self.passive + or (not advertisement.is_scannable) + or ( + self.last_advertisement is not None + and not self.last_advertisement.is_scan_response + ) ): # Don't wait for a scan response result = Advertisement.from_advertising_report(report) @@ -207,18 +230,20 @@ class AdvertisementDataAccumulator: # ----------------------------------------------------------------------------- class AdvertisingType(IntEnum): + # fmt: off UNDIRECTED_CONNECTABLE_SCANNABLE = 0x00 # Undirected, connectable, scannable DIRECTED_CONNECTABLE_HIGH_DUTY = 0x01 # Directed, connectable, non-scannable UNDIRECTED_SCANNABLE = 0x02 # Undirected, non-connectable, scannable UNDIRECTED = 0x03 # Undirected, non-connectable, non-scannable DIRECTED_CONNECTABLE_LOW_DUTY = 0x04 # Directed, connectable, non-scannable + # fmt: on @property def has_data(self): return self in { AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, AdvertisingType.UNDIRECTED_SCANNABLE, - AdvertisingType.UNDIRECTED + AdvertisingType.UNDIRECTED, } @property @@ -226,28 +251,28 @@ class AdvertisingType(IntEnum): return self in { AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY, - AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY + AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY, } @property def is_scannable(self): return self in { AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, - AdvertisingType.UNDIRECTED_SCANNABLE + AdvertisingType.UNDIRECTED_SCANNABLE, } @property def is_directed(self): return self in { AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY, - AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY + AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY, } # ----------------------------------------------------------------------------- class LePhyOptions: # Coded PHY preference - ANY_CODED_PHY = 0 + ANY_CODED_PHY = 0 PREFER_S_2_CODED_PHY = 1 PREFER_S_8_CODED_PHY = 2 @@ -279,23 +304,31 @@ class Peer: async def discover_service(self, uuid): return await self.gatt_client.discover_service(uuid) - async def discover_services(self, uuids = []): + async def discover_services(self, uuids=[]): return await self.gatt_client.discover_services(uuids) async def discover_included_services(self, service): return await self.gatt_client.discover_included_services(service) - async def discover_characteristics(self, uuids = [], service = None): - return await self.gatt_client.discover_characteristics(uuids = uuids, service = service) + async def discover_characteristics(self, uuids=[], service=None): + return await self.gatt_client.discover_characteristics( + uuids=uuids, service=service + ) - async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None): - return await self.gatt_client.discover_descriptors(characteristic, start_handle, end_handle) + async def discover_descriptors( + self, characteristic=None, start_handle=None, end_handle=None + ): + return await self.gatt_client.discover_descriptors( + characteristic, start_handle, end_handle + ) async def discover_attributes(self): return await self.gatt_client.discover_attributes() async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): - return await self.gatt_client.subscribe(characteristic, subscriber, prefer_notify) + return await self.gatt_client.subscribe( + characteristic, subscriber, prefer_notify + ) async def unsubscribe(self, characteristic, subscriber=None): return await self.gatt_client.unsubscribe(characteristic, subscriber) @@ -312,7 +345,7 @@ class Peer: def get_services_by_uuid(self, uuid): return self.gatt_client.get_services_by_uuid(uuid) - def get_characteristics_by_uuid(self, uuid, service = None): + def get_characteristics_by_uuid(self, uuid, service=None): return self.gatt_client.get_characteristics_by_uuid(uuid, service) def create_service_proxy(self, proxy_class): @@ -352,10 +385,10 @@ class Peer: class ConnectionParametersPreferences: connection_interval_min: int = DEVICE_DEFAULT_CONNECTION_INTERVAL_MIN connection_interval_max: int = DEVICE_DEFAULT_CONNECTION_INTERVAL_MAX - max_latency: int = DEVICE_DEFAULT_CONNECTION_MAX_LATENCY - supervision_timeout: int = DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT - min_ce_length: int = DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH - max_ce_length: int = DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH + max_latency: int = DEVICE_DEFAULT_CONNECTION_MAX_LATENCY + supervision_timeout: int = DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT + min_ce_length: int = DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH + max_ce_length: int = DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH ConnectionParametersPreferences.default = ConnectionParametersPreferences() @@ -399,28 +432,30 @@ class Connection(CompositeEventEmitter): peer_resolvable_address, role, parameters, - phy + phy, ): super().__init__() - self.device = device - self.handle = handle - self.transport = transport - self.self_address = self_address - self.peer_address = peer_address + self.device = device + self.handle = handle + self.transport = transport + self.self_address = self_address + self.peer_address = peer_address self.peer_resolvable_address = peer_resolvable_address - self.peer_name = None # Classic only - self.role = role - self.parameters = parameters - self.encryption = 0 - self.authenticated = False - self.sc = False - self.link_key_type = None - self.authenticating = False - self.phy = phy - self.att_mtu = ATT_DEFAULT_MTU - self.data_length = DEVICE_DEFAULT_DATA_LENGTH - self.gatt_client = None # Per-connection client - self.gatt_server = device.gatt_server # By default, use the device's shared server + self.peer_name = None # Classic only + self.role = role + self.parameters = parameters + self.encryption = 0 + self.authenticated = False + self.sc = False + self.link_key_type = None + self.authenticating = False + self.phy = phy + self.att_mtu = ATT_DEFAULT_MTU + self.data_length = DEVICE_DEFAULT_DATA_LENGTH + self.gatt_client = None # Per-connection client + self.gatt_server = ( + device.gatt_server + ) # By default, use the device's shared server # [Classic only] @classmethod @@ -429,7 +464,17 @@ class Connection(CompositeEventEmitter): Instantiate an incomplete connection (ie. one waiting for a HCI Connection Complete event). Once received it shall be completed using the `.complete` method. """ - return cls(device, None, BT_BR_EDR_TRANSPORT, device.public_address, peer_address, None, None, None, None) + return cls( + device, + None, + BT_BR_EDR_TRANSPORT, + device.public_address, + peer_address, + None, + None, + None, + None, + ) # [Classic only] def complete(self, handle, peer_resolvable_address, role, parameters): @@ -438,10 +483,10 @@ class Connection(CompositeEventEmitter): """ assert self.handle is None assert self.transport == BT_BR_EDR_TRANSPORT - self.handle = handle + self.handle = handle self.peer_resolvable_address = peer_resolvable_address - self.role = role - self.parameters = parameters + self.role = role + self.parameters = parameters @property def role_name(self): @@ -462,7 +507,7 @@ class Connection(CompositeEventEmitter): psm, max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, - mps=DEVICE_DEFAULT_L2CAP_COC_MPS + mps=DEVICE_DEFAULT_L2CAP_COC_MPS, ): return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps) @@ -483,7 +528,7 @@ class Connection(CompositeEventEmitter): return await self.device.encrypt(self) async def sustain(self, timeout=None): - """ Idles the current task waiting for a disconnect or timeout """ + """Idles the current task waiting for a disconnect or timeout""" abort = asyncio.get_running_loop().create_future() self.on('disconnection', abort.set_result) @@ -502,14 +547,14 @@ class Connection(CompositeEventEmitter): connection_interval_min, connection_interval_max, max_latency, - supervision_timeout + supervision_timeout, ): return await self.device.update_connection_parameters( self, connection_interval_min, connection_interval_max, max_latency, - supervision_timeout + supervision_timeout, ) async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None): @@ -545,24 +590,26 @@ class Connection(CompositeEventEmitter): class DeviceConfiguration: def __init__(self): # Setup defaults - self.name = DEVICE_DEFAULT_NAME - self.address = DEVICE_DEFAULT_ADDRESS - self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE - self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA + self.name = DEVICE_DEFAULT_NAME + self.address = DEVICE_DEFAULT_ADDRESS + self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE + self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL - self.le_enabled = True + self.le_enabled = True # LE host enable 2nd parameter - self.le_simultaneous_enabled = True - self.classic_sc_enabled = True - self.classic_ssp_enabled = True - self.classic_accept_any = True - self.connectable = True - self.discoverable = True + self.le_simultaneous_enabled = True + self.classic_sc_enabled = True + self.classic_ssp_enabled = True + self.classic_accept_any = True + self.connectable = True + self.discoverable = True self.advertising_data = bytes( - AdvertisingData([(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]) + AdvertisingData( + [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))] + ) ) - self.irk = bytes(16) # This really must be changed for any level of security + self.irk = bytes(16) # This really must be changed for any level of security self.keystore = None self.gatt_services = [] @@ -571,17 +618,27 @@ class DeviceConfiguration: self.name = config.get('name', self.name) self.address = Address(config.get('address', self.address)) self.class_of_device = config.get('class_of_device', self.class_of_device) - self.advertising_interval_min = config.get('advertising_interval', self.advertising_interval_min) + self.advertising_interval_min = config.get( + 'advertising_interval', self.advertising_interval_min + ) self.advertising_interval_max = self.advertising_interval_min - self.keystore = config.get('keystore') - self.le_enabled = config.get('le_enabled', self.le_enabled) - self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled) - self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled) - self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled) - self.classic_accept_any = config.get('classic_accept_any', self.classic_accept_any) - self.connectable = config.get('connectable', self.connectable) - self.discoverable = config.get('discoverable', self.discoverable) - self.gatt_services = config.get('gatt_services', self.gatt_services) + self.keystore = config.get('keystore') + self.le_enabled = config.get('le_enabled', self.le_enabled) + self.le_simultaneous_enabled = config.get( + 'le_simultaneous_enabled', self.le_simultaneous_enabled + ) + self.classic_sc_enabled = config.get( + 'classic_sc_enabled', self.classic_sc_enabled + ) + self.classic_ssp_enabled = config.get( + 'classic_ssp_enabled', self.classic_ssp_enabled + ) + self.classic_accept_any = config.get( + 'classic_accept_any', self.classic_accept_any + ) + self.connectable = config.get('connectable', self.connectable) + self.discoverable = config.get('discoverable', self.discoverable) + self.gatt_services = config.get('gatt_services', self.gatt_services) # Load or synthesize an IRK irk = config.get('irk') @@ -617,6 +674,7 @@ def with_connection_from_handle(function): if (connection := self.lookup_connection(connection_handle)) is None: raise ValueError(f"no connection for handle: 0x{connection_handle:04x}") return function(self, connection, *args, **kwargs) + return wrapper @@ -624,12 +682,13 @@ def with_connection_from_handle(function): def with_connection_from_address(function): @functools.wraps(function) def wrapper(self, address, *args, **kwargs): - if (connection := self.pending_connections.get(address, False)): + if connection := self.pending_connections.get(address, False): return function(self, connection, *args, **kwargs) for connection in self.connections.values(): if connection.peer_address == address: return function(self, connection, *args, **kwargs) raise ValueError('no connection for address') + return wrapper @@ -637,12 +696,13 @@ def with_connection_from_address(function): def try_with_connection_from_address(function): @functools.wraps(function) def wrapper(self, address, *args, **kwargs): - if (connection := self.pending_connections.get(address, False)): + if connection := self.pending_connections.get(address, False): return function(self, connection, address, *args, **kwargs) for connection in self.connections.values(): if connection.peer_address == address: return function(self, connection, address, *args, **kwargs) return function(self, None, address, *args, **kwargs) + return wrapper @@ -661,7 +721,6 @@ device_host_event_handlers = [] # ----------------------------------------------------------------------------- class Device(CompositeEventEmitter): - @composite_listener class Listener: def on_advertisement(self, advertisement): @@ -679,7 +738,9 @@ class Device(CompositeEventEmitter): def on_connection_request(self, bd_addr, class_of_device, link_type): pass - def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled): + def on_characteristic_subscription( + self, connection, characteristic, notify_enabled, indicate_enabled + ): pass @classmethod @@ -688,8 +749,8 @@ class Device(CompositeEventEmitter): Create a Device instance with a Host configured to communicate with a controller through an HCI source/sink ''' - host = Host(controller_source = hci_source, controller_sink = hci_sink) - return cls(name = name, address = address, host = host) + host = Host(controller_source=hci_source, controller_sink=hci_sink) + return cls(name=name, address=address, host=host) @classmethod def from_config_file(cls, filename): @@ -701,61 +762,70 @@ class Device(CompositeEventEmitter): def from_config_file_with_hci(cls, filename, hci_source, hci_sink): config = DeviceConfiguration() config.load_from_file(filename) - host = Host(controller_source = hci_source, controller_sink = hci_sink) - return cls(config = config, host = host) + host = Host(controller_source=hci_source, controller_sink=hci_sink) + return cls(config=config, host=host) - def __init__(self, name = None, address = None, config = None, host = None, generic_access_service = True): + def __init__( + self, + name=None, + address=None, + config=None, + host=None, + generic_access_service=True, + ): super().__init__() - self._host = None - self.powered_on = False - self.advertising = False - self.advertising_type = None - self.auto_restart_inquiry = True - self.auto_restart_advertising = False - self.command_timeout = 10 # seconds - self.gatt_server = gatt_server.Server(self) - self.sdp_server = sdp.Server(self) - self.l2cap_channel_manager = l2cap.ChannelManager( + self._host = None + self.powered_on = False + self.advertising = False + self.advertising_type = None + self.auto_restart_inquiry = True + self.auto_restart_advertising = False + self.command_timeout = 10 # seconds + self.gatt_server = gatt_server.Server(self) + self.sdp_server = sdp.Server(self) + self.l2cap_channel_manager = l2cap.ChannelManager( [l2cap.L2CAP_Information_Request.EXTENDED_FEATURE_FIXED_CHANNELS] ) self.advertisement_accumulators = {} # Accumulators, by address - self.scanning = False - self.scanning_is_passive = False - self.discovering = False - self.le_connecting = False - self.disconnecting = False - self.connections = {} # Connections, by connection handle - self.pending_connections = {} # Connections, by BD address (BR/EDR only) - self.classic_enabled = False - self.inquiry_response = None - self.address_resolver = None - self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY + self.scanning = False + self.scanning_is_passive = False + self.discovering = False + self.le_connecting = False + self.disconnecting = False + self.connections = {} # Connections, by connection handle + self.pending_connections = {} # Connections, by BD address (BR/EDR only) + self.classic_enabled = False + self.inquiry_response = None + self.address_resolver = None + self.classic_pending_accepts = { + Address.ANY: [] + } # Futures, by BD address OR [Futures] for Address.ANY # Own address type cache self.advertising_own_address_type = None - self.connect_own_address_type = None + self.connect_own_address_type = None # Use the initial config or a default self.public_address = Address('00:00:00:00:00:00') if config is None: config = DeviceConfiguration() - self.name = config.name - self.random_address = config.address - self.class_of_device = config.class_of_device - self.scan_response_data = config.scan_response_data - self.advertising_data = config.advertising_data + self.name = config.name + self.random_address = config.address + self.class_of_device = config.class_of_device + self.scan_response_data = config.scan_response_data + self.advertising_data = config.advertising_data self.advertising_interval_min = config.advertising_interval_min self.advertising_interval_max = config.advertising_interval_max - self.keystore = keys.KeyStore.create_for_device(config) - self.irk = config.irk - self.le_enabled = config.le_enabled - self.le_simultaneous_enabled = config.le_simultaneous_enabled - self.classic_ssp_enabled = config.classic_ssp_enabled - self.classic_sc_enabled = config.classic_sc_enabled - self.discoverable = config.discoverable - self.connectable = config.connectable - self.classic_accept_any = config.classic_accept_any + self.keystore = keys.KeyStore.create_for_device(config) + self.irk = config.irk + self.le_enabled = config.le_enabled + self.le_simultaneous_enabled = config.le_simultaneous_enabled + self.classic_ssp_enabled = config.classic_ssp_enabled + self.classic_sc_enabled = config.classic_sc_enabled + self.discoverable = config.discoverable + self.connectable = config.connectable + self.classic_accept_any = config.classic_accept_any for service in config.gatt_services: characteristics = [] @@ -789,10 +859,10 @@ class Device(CompositeEventEmitter): # Setup SMP self.smp_manager = smp.Manager(self) + self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu) self.l2cap_channel_manager.register_fixed_channel( - smp.SMP_CID, self.on_smp_pdu) - self.l2cap_channel_manager.register_fixed_channel( - smp.SMP_BR_CID, self.on_smp_pdu) + smp.SMP_BR_CID, self.on_smp_pdu + ) # Register the SDP server with the L2CAP Channel Manager self.sdp_server.register(self.l2cap_channel_manager) @@ -815,7 +885,9 @@ class Device(CompositeEventEmitter): # Unsubscribe from events from the current host if self._host: for event_name in device_host_event_handlers: - self._host.remove_listener(event_name, getattr(self, f'on_{event_name}')) + self._host.remove_listener( + event_name, getattr(self, f'on_{event_name}') + ) # Subscribe to events from the new host if host: @@ -823,13 +895,13 @@ class Device(CompositeEventEmitter): host.on(event_name, getattr(self, f'on_{event_name}')) # Update the references to the new host - self._host = host + self._host = host self.l2cap_channel_manager.host = host # Set providers for the new host if host: host.long_term_key_provider = self.get_long_term_key - host.link_key_provider = self.get_link_key + host.link_key_provider = self.get_link_key @property def sdp_service_records(self): @@ -843,10 +915,15 @@ class Device(CompositeEventEmitter): if connection := self.connections.get(connection_handle): return connection - def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False): + def find_connection_by_bd_addr( + self, bd_addr, transport=None, check_address_type=False + ): for connection in self.connections.values(): if connection.peer_address.to_bytes() == bd_addr.to_bytes(): - if check_address_type and connection.peer_address.address_type != bd_addr.address_type: + if ( + check_address_type + and connection.peer_address.address_type != bd_addr.address_type + ): continue if transport is None or connection.transport == transport: return connection @@ -866,9 +943,11 @@ class Device(CompositeEventEmitter): server, max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, - mps=DEVICE_DEFAULT_L2CAP_COC_MPS + mps=DEVICE_DEFAULT_L2CAP_COC_MPS, ): - return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps) + return self.l2cap_channel_manager.register_le_coc_server( + psm, server, max_credits, mtu, mps + ) async def open_l2cap_channel( self, @@ -876,9 +955,11 @@ class Device(CompositeEventEmitter): psm, max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, - mps=DEVICE_DEFAULT_L2CAP_COC_MPS + mps=DEVICE_DEFAULT_L2CAP_COC_MPS, ): - return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps) + return await self.l2cap_channel_manager.open_le_coc( + connection, psm, max_credits, mtu, mps + ) def send_l2cap_pdu(self, connection_handle, cid, pdu): self.host.send_l2cap_pdu(connection_handle, cid, pdu) @@ -886,8 +967,7 @@ class Device(CompositeEventEmitter): async def send_command(self, command, check_result=False): try: return await asyncio.wait_for( - self.host.send_command(command, check_result), - self.command_timeout + self.host.send_command(command, check_result), self.command_timeout ) except asyncio.TimeoutError: logger.warning('!!! Command timed out') @@ -899,33 +979,40 @@ class Device(CompositeEventEmitter): response = await self.send_command(HCI_Read_BD_ADDR_Command()) if response.return_parameters.status == HCI_SUCCESS: - logger.debug(color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow')) + logger.debug( + color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow') + ) self.public_address = response.return_parameters.bd_addr if self.host.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND): - await self.send_command(HCI_Write_LE_Host_Support_Command( - le_supported_host = int(self.le_enabled), - simultaneous_le_host = int(self.le_simultaneous_enabled), - )) + await self.send_command( + HCI_Write_LE_Host_Support_Command( + le_supported_host=int(self.le_enabled), + simultaneous_le_host=int(self.le_simultaneous_enabled), + ) + ) if self.le_enabled: # Set the controller address - await self.send_command(HCI_LE_Set_Random_Address_Command( - random_address = self.random_address - ), check_result=True) + await self.send_command( + HCI_LE_Set_Random_Address_Command(random_address=self.random_address), + check_result=True, + ) # Load the address resolving list - if self.keystore and self.host.supports_command(HCI_LE_CLEAR_RESOLVING_LIST_COMMAND): + if self.keystore and self.host.supports_command( + HCI_LE_CLEAR_RESOLVING_LIST_COMMAND + ): await self.send_command(HCI_LE_Clear_Resolving_List_Command()) resolving_keys = await self.keystore.get_resolving_keys() for (irk, address) in resolving_keys: await self.send_command( HCI_LE_Add_Device_To_Resolving_List_Command( - peer_identity_address_type = address.address_type, - peer_identity_address = address, - peer_irk = irk, - local_irk = self.irk + peer_identity_address_type=address.address_type, + peer_identity_address=address, + peer_irk=irk, + local_irk=self.irk, ) ) @@ -942,15 +1029,17 @@ class Device(CompositeEventEmitter): HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) ) await self.send_command( - HCI_Write_Class_Of_Device_Command(class_of_device = self.class_of_device) + HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) ) await self.send_command( HCI_Write_Simple_Pairing_Mode_Command( - simple_pairing_mode=int(self.classic_ssp_enabled)) + simple_pairing_mode=int(self.classic_ssp_enabled) + ) ) await self.send_command( HCI_Write_Secure_Connections_Host_Support_Command( - secure_connections_host_support=int(self.classic_sc_enabled)) + secure_connections_host_support=int(self.classic_sc_enabled) + ) ) await self.set_connectable(self.connectable) await self.set_discoverable(self.discoverable) @@ -970,8 +1059,8 @@ class Device(CompositeEventEmitter): return True feature_map = { - HCI_LE_2M_PHY: HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE, - HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE + HCI_LE_2M_PHY: HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE, + HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE, } if phy not in feature_map: raise ValueError('invalid PHY') @@ -983,7 +1072,7 @@ class Device(CompositeEventEmitter): advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, target=None, own_address_type=OwnAddressType.RANDOM, - auto_restart=False + auto_restart=False, ): # If we're advertising, stop first if self.advertising: @@ -991,60 +1080,71 @@ class Device(CompositeEventEmitter): # Set/update the advertising data if the advertising type allows it if advertising_type.has_data: - await self.send_command(HCI_LE_Set_Advertising_Data_Command( - advertising_data = self.advertising_data - ), check_result=True) + await self.send_command( + HCI_LE_Set_Advertising_Data_Command( + advertising_data=self.advertising_data + ), + check_result=True, + ) # Set/update the scan response data if the advertising is scannable if advertising_type.is_scannable: - await self.send_command(HCI_LE_Set_Scan_Response_Data_Command( - scan_response_data = self.scan_response_data - ), check_result=True) + await self.send_command( + HCI_LE_Set_Scan_Response_Data_Command( + scan_response_data=self.scan_response_data + ), + check_result=True, + ) # Decide what peer address to use if advertising_type.is_directed: if target is None: raise ValueError('directed advertising requires a target address') - peer_address = target + peer_address = target peer_address_type = target.address_type else: - peer_address = Address('00:00:00:00:00:00') + peer_address = Address('00:00:00:00:00:00') peer_address_type = Address.PUBLIC_DEVICE_ADDRESS # Set the advertising parameters - await self.send_command(HCI_LE_Set_Advertising_Parameters_Command( - advertising_interval_min = self.advertising_interval_min, - advertising_interval_max = self.advertising_interval_max, - advertising_type = int(advertising_type), - own_address_type = own_address_type, - peer_address_type = peer_address_type, - peer_address = peer_address, - advertising_channel_map = 7, - advertising_filter_policy = 0 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Advertising_Parameters_Command( + advertising_interval_min=self.advertising_interval_min, + advertising_interval_max=self.advertising_interval_max, + advertising_type=int(advertising_type), + own_address_type=own_address_type, + peer_address_type=peer_address_type, + peer_address=peer_address, + advertising_channel_map=7, + advertising_filter_policy=0, + ), + check_result=True, + ) # Enable advertising - await self.send_command(HCI_LE_Set_Advertising_Enable_Command( - advertising_enable = 1 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), + check_result=True, + ) self.advertising_own_address_type = own_address_type - self.auto_restart_advertising = auto_restart - self.advertising_type = advertising_type - self.advertising = True + self.auto_restart_advertising = auto_restart + self.advertising_type = advertising_type + self.advertising = True async def stop_advertising(self): # Disable advertising if self.advertising: - await self.send_command(HCI_LE_Set_Advertising_Enable_Command( - advertising_enable = 0 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), + check_result=True, + ) self.advertising_own_address_type = None - self.advertising = False - self.advertising_type = None - self.auto_restart_advertising = False + self.advertising = False + self.advertising_type = None + self.auto_restart_advertising = False @property def is_advertising(self): @@ -1055,15 +1155,18 @@ class Device(CompositeEventEmitter): legacy=False, active=True, scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms - scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms + scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms own_address_type=OwnAddressType.RANDOM, filter_duplicates=False, - scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY) + scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY), ): # Check that the arguments are legal if scan_interval < scan_window: raise ValueError('scan_interval must be >= scan_window') - if scan_interval < DEVICE_MIN_SCAN_INTERVAL or scan_interval > DEVICE_MAX_SCAN_INTERVAL: + if ( + scan_interval < DEVICE_MIN_SCAN_INTERVAL + or scan_interval > DEVICE_MAX_SCAN_INTERVAL + ): raise ValueError('scan_interval out of range') if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW: raise ValueError('scan_interval out of range') @@ -1072,10 +1175,18 @@ class Device(CompositeEventEmitter): self.advertisement_accumulator = {} # Enable scanning - if not legacy and self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE): + if not legacy and self.supports_le_feature( + HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE + ): # Set the scanning parameters - scan_type = HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING if active else HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING - scanning_filter_policy = HCI_LE_Set_Extended_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY # TODO: support other types + scan_type = ( + HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING + if active + else HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING + ) + scanning_filter_policy = ( + HCI_LE_Set_Extended_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY + ) # TODO: support other types scanning_phy_count = 0 scanning_phys_bits = 0 @@ -1090,56 +1201,71 @@ class Device(CompositeEventEmitter): if scanning_phy_count == 0: raise ValueError('at least one scanning PHY must be enabled') - await self.send_command(HCI_LE_Set_Extended_Scan_Parameters_Command( - own_address_type = own_address_type, - scanning_filter_policy = scanning_filter_policy, - scanning_phys = scanning_phys_bits, - scan_types = [scan_type] * scanning_phy_count, - scan_intervals = [int(scan_window / 0.625)] * scanning_phy_count, - scan_windows = [int(scan_window / 0.625)] * scanning_phy_count - ), check_result=True) + await self.send_command( + HCI_LE_Set_Extended_Scan_Parameters_Command( + own_address_type=own_address_type, + scanning_filter_policy=scanning_filter_policy, + scanning_phys=scanning_phys_bits, + scan_types=[scan_type] * scanning_phy_count, + scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count, + scan_windows=[int(scan_window / 0.625)] * scanning_phy_count, + ), + check_result=True, + ) # Enable scanning - await self.send_command(HCI_LE_Set_Extended_Scan_Enable_Command( - enable = 1, - filter_duplicates = 1 if filter_duplicates else 0, - duration = 0, # TODO allow other values - period = 0 # TODO allow other values - ), check_result=True) + await self.send_command( + HCI_LE_Set_Extended_Scan_Enable_Command( + enable=1, + filter_duplicates=1 if filter_duplicates else 0, + duration=0, # TODO allow other values + period=0, # TODO allow other values + ), + check_result=True, + ) else: # Set the scanning parameters - scan_type = HCI_LE_Set_Scan_Parameters_Command.ACTIVE_SCANNING if active else HCI_LE_Set_Scan_Parameters_Command.PASSIVE_SCANNING - await self.send_command(HCI_LE_Set_Scan_Parameters_Command( - le_scan_type = scan_type, - le_scan_interval = int(scan_window / 0.625), - le_scan_window = int(scan_window / 0.625), - own_address_type = own_address_type, - scanning_filter_policy = HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY - ), check_result=True) + scan_type = ( + HCI_LE_Set_Scan_Parameters_Command.ACTIVE_SCANNING + if active + else HCI_LE_Set_Scan_Parameters_Command.PASSIVE_SCANNING + ) + await self.send_command( + HCI_LE_Set_Scan_Parameters_Command( + le_scan_type=scan_type, + le_scan_interval=int(scan_window / 0.625), + le_scan_window=int(scan_window / 0.625), + own_address_type=own_address_type, + scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY, + ), + check_result=True, + ) # Enable scanning - await self.send_command(HCI_LE_Set_Scan_Enable_Command( - le_scan_enable = 1, - filter_duplicates = 1 if filter_duplicates else 0 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Scan_Enable_Command( + le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0 + ), + check_result=True, + ) self.scanning_is_passive = not active - self.scanning = True + self.scanning = True async def stop_scanning(self): # Disable scanning if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE): - await self.send_command(HCI_LE_Set_Extended_Scan_Enable_Command( - enable = 0, - filter_duplicates = 0, - duration = 0, - period = 0 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Extended_Scan_Enable_Command( + enable=0, filter_duplicates=0, duration=0, period=0 + ), + check_result=True, + ) else: - await self.send_command(HCI_LE_Set_Scan_Enable_Command( - le_scan_enable = 0, - filter_duplicates = 0 - ), check_result=True) + await self.send_command( + HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), + check_result=True, + ) self.scanning = False @@ -1156,27 +1282,30 @@ class Device(CompositeEventEmitter): self.emit('advertisement', advertisement) async def start_discovery(self, auto_restart=True): - await self.send_command(HCI_Write_Inquiry_Mode_Command( - inquiry_mode=HCI_EXTENDED_INQUIRY_MODE - ), check_result=True) + await self.send_command( + HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), + check_result=True, + ) - response = await self.send_command(HCI_Inquiry_Command( - lap = HCI_GENERAL_INQUIRY_LAP, - inquiry_length = DEVICE_DEFAULT_INQUIRY_LENGTH, - num_responses = 0 # Unlimited number of responses. - )) + response = await self.send_command( + HCI_Inquiry_Command( + lap=HCI_GENERAL_INQUIRY_LAP, + inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH, + num_responses=0, # Unlimited number of responses. + ) + ) if response.status != HCI_Command_Status_Event.PENDING: self.discovering = False raise HCI_StatusError(response) self.auto_restart_inquiry = auto_restart - self.discovering = True + self.discovering = True async def stop_discovery(self): if self.discovering: await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) self.auto_restart_inquiry = True - self.discovering = False + self.discovering = False @host_event_handler def on_inquiry_result(self, address, class_of_device, data, rssi): @@ -1185,7 +1314,7 @@ class Device(CompositeEventEmitter): address, class_of_device, AdvertisingData.from_bytes(data), - rssi + rssi, ) async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled): @@ -1198,7 +1327,9 @@ class Device(CompositeEventEmitter): else: scan_enable = 0x00 - return await self.send_command(HCI_Write_Scan_Enable_Command(scan_enable = scan_enable)) + return await self.send_command( + HCI_Write_Scan_Enable_Command(scan_enable=scan_enable) + ) async def set_discoverable(self, discoverable=True): self.discoverable = discoverable @@ -1206,30 +1337,34 @@ class Device(CompositeEventEmitter): # Synthesize an inquiry response if none is set already if self.inquiry_response is None: self.inquiry_response = bytes( - AdvertisingData([ - (AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8')) - ]) + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(self.name, 'utf-8'), + ) + ] + ) ) # Update the controller await self.send_command( HCI_Write_Extended_Inquiry_Response_Command( - fec_required = 0, - extended_inquiry_response = self.inquiry_response + fec_required=0, extended_inquiry_response=self.inquiry_response ), - check_result=True + check_result=True, ) await self.set_scan_enable( - inquiry_scan_enabled = self.discoverable, - page_scan_enabled = self.connectable + inquiry_scan_enabled=self.discoverable, + page_scan_enabled=self.connectable, ) async def set_connectable(self, connectable=True): self.connectable = connectable if self.classic_enabled: await self.set_scan_enable( - inquiry_scan_enabled = self.discoverable, - page_scan_enabled = self.connectable + inquiry_scan_enabled=self.discoverable, + page_scan_enabled=self.connectable, ) async def connect( @@ -1238,7 +1373,7 @@ class Device(CompositeEventEmitter): transport=BT_LE_TRANSPORT, connection_parameters_preferences=None, own_address_type=OwnAddressType.RANDOM, - timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT + timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT, ): ''' Request a connection to a peer. @@ -1267,27 +1402,36 @@ class Device(CompositeEventEmitter): if type(peer_address) is str: try: - peer_address = Address.from_string_for_transport(peer_address, transport) + peer_address = Address.from_string_for_transport( + peer_address, transport + ) except ValueError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') - peer_address = await self.find_peer_by_name(peer_address, transport) # TODO: timeout + peer_address = await self.find_peer_by_name( + peer_address, transport + ) # TODO: timeout else: # All BR/EDR addresses should be public addresses - if transport == BT_BR_EDR_TRANSPORT and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS: + if ( + transport == BT_BR_EDR_TRANSPORT + and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS + ): raise ValueError('BR/EDR addresses must be PUBLIC') def on_connection(connection): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection event against peer address - connection.transport == transport and connection.peer_address == peer_address + connection.transport == transport + and connection.peer_address == peer_address ): pending_connection.set_result(connection) def on_connection_failure(error): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection failure event against peer address - error.transport == transport and error.peer_address == peer_address + error.transport == transport + and error.peer_address == peer_address ): pending_connection.set_exception(error) @@ -1302,16 +1446,27 @@ class Device(CompositeEventEmitter): if connection_parameters_preferences is None: if connection_parameters_preferences is None: connection_parameters_preferences = { - HCI_LE_1M_PHY: ConnectionParametersPreferences.default, - HCI_LE_2M_PHY: ConnectionParametersPreferences.default, - HCI_LE_CODED_PHY: ConnectionParametersPreferences.default + HCI_LE_1M_PHY: ConnectionParametersPreferences.default, + HCI_LE_2M_PHY: ConnectionParametersPreferences.default, + HCI_LE_CODED_PHY: ConnectionParametersPreferences.default, } self.connect_own_address_type = own_address_type - if self.host.supports_command(HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND): + if self.host.supports_command( + HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND + ): # Only keep supported PHYs - phys = sorted(list(set(filter(self.supports_le_phy, connection_parameters_preferences.keys())))) + phys = sorted( + list( + set( + filter( + self.supports_le_phy, + connection_parameters_preferences.keys(), + ) + ) + ) + ) if not phys: raise ValueError('least one supported PHY needed') @@ -1319,71 +1474,116 @@ class Device(CompositeEventEmitter): initiating_phys = phy_list_to_bits(phys) connection_interval_mins = [ - int(connection_parameters_preferences[phy].connection_interval_min / 1.25) for phy in phys + int( + connection_parameters_preferences[ + phy + ].connection_interval_min + / 1.25 + ) + for phy in phys ] connection_interval_maxs = [ - int(connection_parameters_preferences[phy].connection_interval_max / 1.25) for phy in phys + int( + connection_parameters_preferences[ + phy + ].connection_interval_max + / 1.25 + ) + for phy in phys ] max_latencies = [ - connection_parameters_preferences[phy].max_latency for phy in phys + connection_parameters_preferences[phy].max_latency + for phy in phys ] supervision_timeouts = [ - int(connection_parameters_preferences[phy].supervision_timeout / 10) for phy in phys + int( + connection_parameters_preferences[phy].supervision_timeout + / 10 + ) + for phy in phys ] min_ce_lengths = [ - int(connection_parameters_preferences[phy].min_ce_length / 0.625) for phy in phys + int( + connection_parameters_preferences[phy].min_ce_length / 0.625 + ) + for phy in phys ] max_ce_lengths = [ - int(connection_parameters_preferences[phy].max_ce_length / 0.625) for phy in phys + int( + connection_parameters_preferences[phy].max_ce_length / 0.625 + ) + for phy in phys ] - result = await self.send_command(HCI_LE_Extended_Create_Connection_Command( - initiator_filter_policy = 0, - own_address_type = own_address_type, - peer_address_type = peer_address.address_type, - peer_address = peer_address, - initiating_phys = initiating_phys, - scan_intervals = (int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625),) * phy_count, - scan_windows = (int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),) * phy_count, - connection_interval_mins = connection_interval_mins, - connection_interval_maxs = connection_interval_maxs, - max_latencies = max_latencies, - supervision_timeouts = supervision_timeouts, - min_ce_lengths = min_ce_lengths, - max_ce_lengths = max_ce_lengths - )) + result = await self.send_command( + HCI_LE_Extended_Create_Connection_Command( + initiator_filter_policy=0, + own_address_type=own_address_type, + peer_address_type=peer_address.address_type, + peer_address=peer_address, + initiating_phys=initiating_phys, + scan_intervals=( + int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625), + ) + * phy_count, + scan_windows=( + int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625), + ) + * phy_count, + connection_interval_mins=connection_interval_mins, + connection_interval_maxs=connection_interval_maxs, + max_latencies=max_latencies, + supervision_timeouts=supervision_timeouts, + min_ce_lengths=min_ce_lengths, + max_ce_lengths=max_ce_lengths, + ) + ) else: if HCI_LE_1M_PHY not in connection_parameters_preferences: raise ValueError('1M PHY preferences required') prefs = connection_parameters_preferences[HCI_LE_1M_PHY] - result = await self.send_command(HCI_LE_Create_Connection_Command( - le_scan_interval = int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625), - le_scan_window = int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625), - initiator_filter_policy = 0, - peer_address_type = peer_address.address_type, - peer_address = peer_address, - own_address_type = own_address_type, - connection_interval_min = int(prefs.connection_interval_min / 1.25), - connection_interval_max = int(prefs.connection_interval_max / 1.25), - max_latency = prefs.max_latency, - supervision_timeout = int(prefs.supervision_timeout / 10), - min_ce_length = int(prefs.min_ce_length / 0.625), - max_ce_length = int(prefs.max_ce_length / 0.625), - )) + result = await self.send_command( + HCI_LE_Create_Connection_Command( + le_scan_interval=int( + DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625 + ), + le_scan_window=int( + DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625 + ), + initiator_filter_policy=0, + peer_address_type=peer_address.address_type, + peer_address=peer_address, + own_address_type=own_address_type, + connection_interval_min=int( + prefs.connection_interval_min / 1.25 + ), + connection_interval_max=int( + prefs.connection_interval_max / 1.25 + ), + max_latency=prefs.max_latency, + supervision_timeout=int(prefs.supervision_timeout / 10), + min_ce_length=int(prefs.min_ce_length / 0.625), + max_ce_length=int(prefs.max_ce_length / 0.625), + ) + ) else: # Save pending connection - self.pending_connections[peer_address] = Connection.incomplete(self, peer_address) + self.pending_connections[peer_address] = Connection.incomplete( + self, peer_address + ) # TODO: allow passing other settings - result = await self.send_command(HCI_Create_Connection_Command( - bd_addr = peer_address, - packet_type = 0xCC18, # FIXME: change - page_scan_repetition_mode = HCI_R2_PAGE_SCAN_REPETITION_MODE, - clock_offset = 0x0000, - allow_role_switch = 0x01, - reserved = 0 - )) + result = await self.send_command( + HCI_Create_Connection_Command( + bd_addr=peer_address, + packet_type=0xCC18, # FIXME: change + page_scan_repetition_mode=HCI_R2_PAGE_SCAN_REPETITION_MODE, + clock_offset=0x0000, + allow_role_switch=0x01, + reserved=0, + ) + ) if result.status != HCI_Command_Status_Event.PENDING: raise HCI_StatusError(result) @@ -1395,12 +1595,18 @@ class Device(CompositeEventEmitter): return await pending_connection else: try: - return await asyncio.wait_for(asyncio.shield(pending_connection), timeout) + return await asyncio.wait_for( + asyncio.shield(pending_connection), timeout + ) except asyncio.TimeoutError: if transport == BT_LE_TRANSPORT: - await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) + await self.send_command( + HCI_LE_Create_Connection_Cancel_Command() + ) else: - await self.send_command(HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)) + await self.send_command( + HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) + ) try: return await pending_connection @@ -1419,7 +1625,7 @@ class Device(CompositeEventEmitter): self, peer_address=Address.ANY, role=BT_PERIPHERAL_ROLE, - timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT + timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT, ): ''' Wait and accept any incoming connection or a connection from `peer_address` when set. @@ -1436,7 +1642,9 @@ class Device(CompositeEventEmitter): except ValueError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') - peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout + peer_address = await self.find_peer_by_name( + peer_address, BT_BR_EDR_TRANSPORT + ) # TODO: timeout if peer_address == Address.NIL: raise ValueError('accept on nil address') @@ -1453,7 +1661,11 @@ class Device(CompositeEventEmitter): try: # Wait for a request or a completed connection - result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request) + result = await ( + asyncio.wait_for(pending_request, timeout) + if timeout + else pending_request + ) except Exception: # Remove future from device context if peer_address == Address.ANY: @@ -1471,11 +1683,17 @@ class Device(CompositeEventEmitter): peer_address, class_of_device, link_type = result def on_connection(connection): - if connection.transport == BT_BR_EDR_TRANSPORT and connection.peer_address == peer_address: + if ( + connection.transport == BT_BR_EDR_TRANSPORT + and connection.peer_address == peer_address + ): pending_connection.set_result(connection) def on_connection_failure(error): - if error.transport == BT_BR_EDR_TRANSPORT and error.peer_address == peer_address: + if ( + error.transport == BT_BR_EDR_TRANSPORT + and error.peer_address == peer_address + ): pending_connection.set_exception(error) # Create a future so that we can wait for the connection's result @@ -1484,14 +1702,15 @@ class Device(CompositeEventEmitter): self.on('connection_failure', on_connection_failure) # Save pending connection - self.pending_connections[peer_address] = Connection.incomplete(self, peer_address) + self.pending_connections[peer_address] = Connection.incomplete( + self, peer_address + ) try: # Accept connection request - await self.send_command(HCI_Accept_Connection_Request_Command( - bd_addr = peer_address, - role = role - )) + await self.send_command( + HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) + ) # Wait for connection complete return await pending_connection @@ -1504,7 +1723,9 @@ class Device(CompositeEventEmitter): @asynccontextmanager async def connect_as_gatt(self, peer_address): async with AsyncExitStack() as stack: - connection = await stack.enter_async_context(await self.connect(peer_address)) + connection = await stack.enter_async_context( + await self.connect(peer_address) + ) peer = await stack.enter_async_context(Peer(connection)) yield peer @@ -1522,7 +1743,9 @@ class Device(CompositeEventEmitter): if peer_address is None: if not self.is_le_connecting: return - await self.send_command(HCI_LE_Create_Connection_Cancel_Command(), check_result=True) + await self.send_command( + HCI_LE_Create_Connection_Cancel_Command(), check_result=True + ) # BR/EDR: try to cancel to ongoing connection # NOTE: This API does not prevent from trying to cancel a connection which is not currently being created @@ -1533,9 +1756,14 @@ class Device(CompositeEventEmitter): except ValueError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') - peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout + peer_address = await self.find_peer_by_name( + peer_address, BT_BR_EDR_TRANSPORT + ) # TODO: timeout - await self.send_command(HCI_Create_Connection_Cancel_Command(bd_addr=peer_address), check_result=True) + await self.send_command( + HCI_Create_Connection_Cancel_Command(bd_addr=peer_address), + check_result=True, + ) async def disconnect(self, connection, reason): # Create a future so that we can wait for the disconnection's result @@ -1544,9 +1772,9 @@ class Device(CompositeEventEmitter): connection.on('disconnection_failure', pending_disconnection.set_exception) # Request a disconnection - result = await self.send_command(HCI_Disconnect_Command( - connection_handle = connection.handle, reason = reason - )) + result = await self.send_command( + HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) + ) try: if result.status != HCI_Command_Status_Event.PENDING: @@ -1556,8 +1784,12 @@ class Device(CompositeEventEmitter): self.disconnecting = True return await pending_disconnection finally: - connection.remove_listener('disconnection', pending_disconnection.set_result) - connection.remove_listener('disconnection_failure', pending_disconnection.set_exception) + connection.remove_listener( + 'disconnection', pending_disconnection.set_result + ) + connection.remove_listener( + 'disconnection_failure', pending_disconnection.set_exception + ) self.disconnecting = False async def update_connection_parameters( @@ -1567,61 +1799,68 @@ class Device(CompositeEventEmitter): connection_interval_max, max_latency, supervision_timeout, - min_ce_length = 0, - max_ce_length = 0 + min_ce_length=0, + max_ce_length=0, ): ''' NOTE: the name of the parameters may look odd, but it just follows the names used in the Bluetooth spec. ''' - await self.send_command(HCI_LE_Connection_Update_Command( - connection_handle = connection.handle, - connection_interval_min = connection_interval_min, - connection_interval_max = connection_interval_max, - max_latency = max_latency, - supervision_timeout = supervision_timeout, - min_ce_length = min_ce_length, - max_ce_length = max_ce_length - ), check_result=True) + await self.send_command( + HCI_LE_Connection_Update_Command( + connection_handle=connection.handle, + connection_interval_min=connection_interval_min, + connection_interval_max=connection_interval_max, + max_latency=max_latency, + supervision_timeout=supervision_timeout, + min_ce_length=min_ce_length, + max_ce_length=max_ce_length, + ), + check_result=True, + ) async def get_connection_rssi(self, connection): - result = await self.send_command(HCI_Read_RSSI_Command(handle = connection.handle), check_result=True) + result = await self.send_command( + HCI_Read_RSSI_Command(handle=connection.handle), check_result=True + ) return result.return_parameters.rssi async def get_connection_phy(self, connection): result = await self.send_command( - HCI_LE_Read_PHY_Command(connection_handle = connection.handle), - check_result=True + HCI_LE_Read_PHY_Command(connection_handle=connection.handle), + check_result=True, ) return (result.return_parameters.tx_phy, result.return_parameters.rx_phy) async def set_connection_phy( - self, - connection, - tx_phys=None, - rx_phys=None, - phy_options=None + self, connection, tx_phys=None, rx_phys=None, phy_options=None ): - all_phys_bits = (1 if tx_phys is None else 0) | ((1 if rx_phys is None else 0) << 1) + all_phys_bits = (1 if tx_phys is None else 0) | ( + (1 if rx_phys is None else 0) << 1 + ) return await self.send_command( HCI_LE_Set_PHY_Command( - connection_handle = connection.handle, - all_phys = all_phys_bits, - tx_phys = phy_list_to_bits(tx_phys), - rx_phys = phy_list_to_bits(rx_phys), - phy_options = 0 if phy_options is None else int(phy_options) - ), check_result=True + connection_handle=connection.handle, + all_phys=all_phys_bits, + tx_phys=phy_list_to_bits(tx_phys), + rx_phys=phy_list_to_bits(rx_phys), + phy_options=0 if phy_options is None else int(phy_options), + ), + check_result=True, ) async def set_default_phy(self, tx_phys=None, rx_phys=None): - all_phys_bits = (1 if tx_phys is None else 0) | ((1 if rx_phys is None else 0) << 1) + all_phys_bits = (1 if tx_phys is None else 0) | ( + (1 if rx_phys is None else 0) << 1 + ) return await self.send_command( HCI_LE_Set_Default_PHY_Command( - all_phys = all_phys_bits, - tx_phys = phy_list_to_bits(tx_phys), - rx_phys = phy_list_to_bits(rx_phys) - ), check_result=True + all_phys=all_phys_bits, + tx_phys=phy_list_to_bits(tx_phys), + rx_phys=phy_list_to_bits(rx_phys), + ), + check_result=True, ) async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT): @@ -1640,13 +1879,16 @@ class Device(CompositeEventEmitter): if local_name is not None: if local_name.decode('utf-8') == name: peer_address.set_result(address) + try: handler = None if transport == BT_LE_TRANSPORT: event_name = 'advertisement' handler = self.on( event_name, - lambda advertisement: on_peer_found(advertisement.address, advertisement.data) + lambda advertisement: on_peer_found( + advertisement.address, advertisement.data + ), ) was_scanning = self.scanning @@ -1657,8 +1899,9 @@ class Device(CompositeEventEmitter): event_name = 'inquiry_result' handler = self.on( event_name, - lambda address, class_of_device, eir_data, rssi: - on_peer_found(address, eir_data) + lambda address, class_of_device, eir_data, rssi: on_peer_found( + address, eir_data + ), ) was_discovering = self.discovering @@ -1732,15 +1975,19 @@ class Device(CompositeEventEmitter): pending_authentication.set_exception(HCI_Error(error_code)) connection.on('connection_authentication', on_authentication) - connection.on('connection_authentication_failure', on_authentication_failure) + connection.on('connection_authentication_failure', on_authentication_failure) # Request the authentication try: result = await self.send_command( - HCI_Authentication_Requested_Command(connection_handle = connection.handle) + HCI_Authentication_Requested_Command( + connection_handle=connection.handle + ) ) if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}') + logger.warn( + f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}' + ) raise HCI_StatusError(result) # Save in connection we are trying to authenticate @@ -1751,7 +1998,9 @@ class Device(CompositeEventEmitter): finally: connection.authenticating = False connection.remove_listener('connection_authentication', on_authentication) - connection.remove_listener('connection_authentication_failure', on_authentication_failure) + connection.remove_listener( + 'connection_authentication_failure', on_authentication_failure + ) async def encrypt(self, connection): # Set up event handlers @@ -1763,7 +2012,7 @@ class Device(CompositeEventEmitter): def on_encryption_failure(error_code): pending_encryption.set_exception(HCI_Error(error_code)) - connection.on('connection_encryption_change', on_encryption_change) + connection.on('connection_encryption_change', on_encryption_change) connection.on('connection_encryption_failure', on_encryption_failure) # Request the encryption @@ -1778,11 +2027,11 @@ class Device(CompositeEventEmitter): raise RuntimeError('keys not found in key store') if keys.ltk is not None: - ltk = keys.ltk.value + ltk = keys.ltk.value rand = bytes(8) ediv = 0 elif keys.ltk_central is not None: - ltk = keys.ltk_central.value + ltk = keys.ltk_central.value rand = keys.ltk_central.rand ediv = keys.ltk_central.ediv else: @@ -1793,33 +2042,40 @@ class Device(CompositeEventEmitter): result = await self.send_command( HCI_LE_Enable_Encryption_Command( - connection_handle = connection.handle, - random_number = rand, - encrypted_diversifier = ediv, - long_term_key = ltk + connection_handle=connection.handle, + random_number=rand, + encrypted_diversifier=ediv, + long_term_key=ltk, ) ) if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warn(f'HCI_LE_Enable_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') + logger.warn( + f'HCI_LE_Enable_Encryption_Command failed: {HCI_Constant.error_name(result.status)}' + ) raise HCI_StatusError(result) else: result = await self.send_command( HCI_Set_Connection_Encryption_Command( - connection_handle = connection.handle, - encryption_enable = 0x01 + connection_handle=connection.handle, encryption_enable=0x01 ) ) if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warn(f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') + logger.warn( + f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}' + ) raise HCI_StatusError(result) # Wait for the result await pending_encryption finally: - connection.remove_listener('connection_encryption_change', on_encryption_change) - connection.remove_listener('connection_encryption_failure', on_encryption_failure) + connection.remove_listener( + 'connection_encryption_change', on_encryption_change + ) + connection.remove_listener( + 'connection_encryption_failure', on_encryption_failure + ) # [Classic only] async def request_remote_name(self, remote): # remote: Connection | Address @@ -1830,27 +2086,33 @@ class Device(CompositeEventEmitter): handler = self.on( 'remote_name', - lambda address, remote_name: - pending_name.set_result(remote_name) if address == peer_address else None + lambda address, remote_name: pending_name.set_result(remote_name) + if address == peer_address + else None, ) failure_handler = self.on( 'remote_name_failure', - lambda address, error_code: - pending_name.set_exception(HCI_Error(error_code)) if address == peer_address else None + lambda address, error_code: pending_name.set_exception( + HCI_Error(error_code) + ) + if address == peer_address + else None, ) try: result = await self.send_command( HCI_Remote_Name_Request_Command( - bd_addr = peer_address, - page_scan_repetition_mode = HCI_Remote_Name_Request_Command.R0, # TODO investigate other options - reserved = 0, - clock_offset = 0 # TODO investigate non-0 values + bd_addr=peer_address, + page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R0, # TODO investigate other options + reserved=0, + clock_offset=0, # TODO investigate non-0 values ) ) if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warn(f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') + logger.warn( + f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}' + ) raise HCI_StatusError(result) # Wait for the result @@ -1865,7 +2127,7 @@ class Device(CompositeEventEmitter): # Store the keys in the key store if self.keystore: pairing_keys = keys.PairingKeys() - pairing_keys.link_key = keys.PairingKeys.Key(value = link_key) + pairing_keys.link_key = keys.PairingKeys.Key(value=link_key) async def store_keys(): try: @@ -1875,7 +2137,9 @@ class Device(CompositeEventEmitter): asyncio.create_task(store_keys()) - if (connection := self.find_connection_by_bd_addr(bd_addr, transport=BT_BR_EDR_TRANSPORT)): + if connection := self.find_connection_by_bd_addr( + bd_addr, transport=BT_BR_EDR_TRANSPORT + ): connection.link_key_type = key_type def add_service(self, service): @@ -1902,15 +2166,29 @@ class Device(CompositeEventEmitter): await self.gatt_server.indicate_subscribers(attribute, value, force) @host_event_handler - def on_connection(self, connection_handle, transport, peer_address, peer_resolvable_address, role, connection_parameters): - logger.debug(f'*** Connection: [0x{connection_handle:04X}] {peer_address} as {HCI_Constant.role_name(role)}') + def on_connection( + self, + connection_handle, + transport, + peer_address, + peer_resolvable_address, + role, + connection_parameters, + ): + logger.debug( + f'*** Connection: [0x{connection_handle:04X}] {peer_address} as {HCI_Constant.role_name(role)}' + ) if connection_handle in self.connections: - logger.warn('new connection reuses the same handle as a previous connection') + logger.warn( + 'new connection reuses the same handle as a previous connection' + ) if transport == BT_BR_EDR_TRANSPORT: # Create a new connection connection: Connection = self.pending_connections.pop(peer_address) - connection.complete(connection_handle, peer_resolvable_address, role, connection_parameters) + connection.complete( + connection_handle, peer_resolvable_address, role, connection_parameters + ) self.connections[connection_handle] = connection # We may have an accept ongoing waiting for a connection request for `peer_address`. @@ -1943,7 +2221,7 @@ class Device(CompositeEventEmitter): # We are no longer advertising self.advertising_own_address_type = None - self.advertising = False + self.advertising = False # Create and notify of the new connection asynchronously async def new_connection(): @@ -1951,15 +2229,20 @@ class Device(CompositeEventEmitter): if self.host.supports_command(HCI_LE_READ_PHY_COMMAND): result = await self.send_command( HCI_LE_Read_PHY_Command(connection_handle=connection_handle), - check_result=True + check_result=True, + ) + phy = ConnectionPHY( + result.return_parameters.tx_phy, result.return_parameters.rx_phy ) - phy = ConnectionPHY(result.return_parameters.tx_phy, result.return_parameters.rx_phy) else: phy = ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY) self_address = self.random_address - if own_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC): - self_address = self.public_address + if own_address_type in ( + OwnAddressType.PUBLIC, + OwnAddressType.RESOLVABLE_OR_PUBLIC, + ): + self_address = self.public_address # Create a new connection connection = Connection( @@ -1971,7 +2254,7 @@ class Device(CompositeEventEmitter): peer_resolvable_address, role, connection_parameters, - phy + phy, ) self.connections[connection_handle] = connection @@ -1985,9 +2268,13 @@ class Device(CompositeEventEmitter): logger.debug(f'*** Connection failed: {HCI_Constant.error_name(error_code)}') # For directed advertising, this means a timeout - if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed: + if ( + transport == BT_LE_TRANSPORT + and self.advertising + and self.advertising_type.is_directed + ): self.advertising_own_address_type = None - self.advertising = False + self.advertising = False # Notify listeners error = ConnectionError( @@ -1995,7 +2282,7 @@ class Device(CompositeEventEmitter): transport, peer_address, 'hci', - HCI_Constant.error_name(error_code) + HCI_Constant.error_name(error_code), ) self.emit('connection_failure', error) @@ -2021,8 +2308,7 @@ class Device(CompositeEventEmitter): self.host.send_command_sync( HCI_Accept_Connection_Request_Command( - bd_addr = bd_addr, - role = 0x01 # Remain the peripheral + bd_addr=bd_addr, role=0x01 # Remain the peripheral ) ) @@ -2030,15 +2316,17 @@ class Device(CompositeEventEmitter): else: self.host.send_command_sync( HCI_Reject_Connection_Request_Command( - bd_addr = bd_addr, - reason = HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR + bd_addr=bd_addr, + reason=HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR, ) ) @host_event_handler @with_connection_from_handle def on_disconnection(self, connection, reason): - logger.debug(f'*** Disconnection: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, reason={reason}') + logger.debug( + f'*** Disconnection: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, reason={reason}' + ) connection.emit('disconnection', reason) # Remove the connection from the map @@ -2050,10 +2338,11 @@ class Device(CompositeEventEmitter): # Restart advertising if auto-restart is enabled if self.auto_restart_advertising: logger.debug('restarting advertising') - asyncio.create_task(self.start_advertising( - advertising_type = self.advertising_type, - auto_restart = True - )) + asyncio.create_task( + self.start_advertising( + advertising_type=self.advertising_type, auto_restart=True + ) + ) @host_event_handler @with_connection_from_handle @@ -2064,7 +2353,7 @@ class Device(CompositeEventEmitter): connection.transport, connection.peer_address, 'hci', - HCI_Constant.error_name(error_code) + HCI_Constant.error_name(error_code), ) connection.emit('disconnection_failure', error) @@ -2076,20 +2365,24 @@ class Device(CompositeEventEmitter): await self.start_discovery(auto_restart=True) else: self.auto_restart_inquiry = True - self.discovering = False + self.discovering = False self.emit('inquiry_complete') @host_event_handler @with_connection_from_handle def on_connection_authentication(self, connection): - logger.debug(f'*** Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}') + logger.debug( + f'*** Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}' + ) connection.authenticated = True connection.emit('connection_authentication') @host_event_handler @with_connection_from_handle def on_connection_authentication_failure(self, connection, error): - logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') + logger.debug( + f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}' + ) connection.emit('connection_authentication_failure', error) @host_event_handler @@ -2100,7 +2393,9 @@ class Device(CompositeEventEmitter): # - AND we are not the initiator of the authentication # We must trigger authentication to known if we are truly authenticated if not connection.authenticating and not connection.authenticated: - logger.debug(f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}') + logger.debug( + f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}' + ) asyncio.create_task(connection.authenticate()) # [Classic only] @@ -2112,15 +2407,17 @@ class Device(CompositeEventEmitter): # Map the SMP IO capability to a Classic IO capability io_capability = { - smp.SMP_DISPLAY_ONLY_IO_CAPABILITY: HCI_DISPLAY_ONLY_IO_CAPABILITY, - smp.SMP_DISPLAY_YES_NO_IO_CAPABILITY: HCI_DISPLAY_YES_NO_IO_CAPABILITY, - smp.SMP_KEYBOARD_ONLY_IO_CAPABILITY: HCI_KEYBOARD_ONLY_IO_CAPABILITY, + smp.SMP_DISPLAY_ONLY_IO_CAPABILITY: HCI_DISPLAY_ONLY_IO_CAPABILITY, + smp.SMP_DISPLAY_YES_NO_IO_CAPABILITY: HCI_DISPLAY_YES_NO_IO_CAPABILITY, + smp.SMP_KEYBOARD_ONLY_IO_CAPABILITY: HCI_KEYBOARD_ONLY_IO_CAPABILITY, smp.SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, - smp.SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: HCI_DISPLAY_YES_NO_IO_CAPABILITY + smp.SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: HCI_DISPLAY_YES_NO_IO_CAPABILITY, }.get(pairing_config.delegate.io_capability) if io_capability is None: - logger.warning(f'cannot map IO capability ({pairing_config.delegate.io_capability}') + logger.warning( + f'cannot map IO capability ({pairing_config.delegate.io_capability}' + ) io_capability = HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY # Compute the authentication requirements @@ -2128,22 +2425,22 @@ class Device(CompositeEventEmitter): # No Bonding ( HCI_MITM_NOT_REQUIRED_NO_BONDING_AUTHENTICATION_REQUIREMENTS, - HCI_MITM_REQUIRED_NO_BONDING_AUTHENTICATION_REQUIREMENTS + HCI_MITM_REQUIRED_NO_BONDING_AUTHENTICATION_REQUIREMENTS, ), # General Bonding ( HCI_MITM_NOT_REQUIRED_GENERAL_BONDING_AUTHENTICATION_REQUIREMENTS, - HCI_MITM_REQUIRED_GENERAL_BONDING_AUTHENTICATION_REQUIREMENTS - ) + HCI_MITM_REQUIRED_GENERAL_BONDING_AUTHENTICATION_REQUIREMENTS, + ), )[1 if pairing_config.bonding else 0][1 if pairing_config.mitm else 0] # Respond self.host.send_command_sync( HCI_IO_Capability_Request_Reply_Command( - bd_addr = connection.peer_address, - io_capability = io_capability, - oob_data_present = 0x00, # Not present - authentication_requirements = authentication_requirements + bd_addr=connection.peer_address, + io_capability=io_capability, + oob_data_present=0x00, # Not present + authentication_requirements=authentication_requirements, ) ) @@ -2156,33 +2453,45 @@ class Device(CompositeEventEmitter): can_compare = pairing_config.delegate.io_capability not in { smp.SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, - smp.SMP_DISPLAY_ONLY_IO_CAPABILITY + smp.SMP_DISPLAY_ONLY_IO_CAPABILITY, } # Respond if can_compare: + async def compare_numbers(): - numbers_match = await pairing_config.delegate.compare_numbers(code, digits=6) + numbers_match = await pairing_config.delegate.compare_numbers( + code, digits=6 + ) if numbers_match: self.host.send_command_sync( - HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Confirmation_Request_Reply_Command( + bd_addr=connection.peer_address + ) ) else: self.host.send_command_sync( - HCI_User_Confirmation_Request_Negative_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Confirmation_Request_Negative_Reply_Command( + bd_addr=connection.peer_address + ) ) asyncio.create_task(compare_numbers()) else: + async def confirm(): confirm = await pairing_config.delegate.confirm() if confirm: self.host.send_command_sync( - HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Confirmation_Request_Reply_Command( + bd_addr=connection.peer_address + ) ) else: self.host.send_command_sync( - HCI_User_Confirmation_Request_Negative_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Confirmation_Request_Negative_Reply_Command( + bd_addr=connection.peer_address + ) ) asyncio.create_task(confirm()) @@ -2196,28 +2505,33 @@ class Device(CompositeEventEmitter): can_input = pairing_config.delegate.io_capability in { smp.SMP_KEYBOARD_ONLY_IO_CAPABILITY, - smp.SMP_KEYBOARD_DISPLAY_IO_CAPABILITY + smp.SMP_KEYBOARD_DISPLAY_IO_CAPABILITY, } # Respond if can_input: + async def get_number(): number = await pairing_config.delegate.get_number() if number is not None: self.host.send_command_sync( HCI_User_Passkey_Request_Reply_Command( - bd_addr = connection.peer_address, - numeric_value = number) + bd_addr=connection.peer_address, numeric_value=number + ) ) else: self.host.send_command_sync( - HCI_User_Passkey_Request_Negative_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Passkey_Request_Negative_Reply_Command( + bd_addr=connection.peer_address + ) ) asyncio.create_task(get_number()) else: self.host.send_command_sync( - HCI_User_Passkey_Request_Negative_Reply_Command(bd_addr=connection.peer_address) + HCI_User_Passkey_Request_Negative_Reply_Command( + bd_addr=connection.peer_address + ) ) # [Classic only] @@ -2258,60 +2572,85 @@ class Device(CompositeEventEmitter): @host_event_handler @with_connection_from_handle def on_connection_encryption_change(self, connection, encryption): - logger.debug(f'*** Connection Encryption Change: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, encryption={encryption}') + logger.debug( + f'*** Connection Encryption Change: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, encryption={encryption}' + ) connection.encryption = encryption connection.emit('connection_encryption_change') @host_event_handler @with_connection_from_handle def on_connection_encryption_failure(self, connection, error): - logger.debug(f'*** Connection Encryption Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') + logger.debug( + f'*** Connection Encryption Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}' + ) connection.emit('connection_encryption_failure', error) @host_event_handler @with_connection_from_handle def on_connection_encryption_key_refresh(self, connection): - logger.debug(f'*** Connection Key Refresh: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}') + logger.debug( + f'*** Connection Key Refresh: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}' + ) connection.emit('connection_encryption_key_refresh') @host_event_handler @with_connection_from_handle def on_connection_parameters_update(self, connection, connection_parameters): - logger.debug(f'*** Connection Parameters Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {connection_parameters}') + logger.debug( + f'*** Connection Parameters Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {connection_parameters}' + ) connection.parameters = connection_parameters connection.emit('connection_parameters_update') @host_event_handler @with_connection_from_handle def on_connection_parameters_update_failure(self, connection, error): - logger.debug(f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') + logger.debug( + f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}' + ) connection.emit('connection_parameters_update_failure', error) @host_event_handler @with_connection_from_handle def on_connection_phy_update(self, connection, connection_phy): - logger.debug(f'*** Connection PHY Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {connection_phy}') + logger.debug( + f'*** Connection PHY Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {connection_phy}' + ) connection.phy = connection_phy connection.emit('connection_phy_update') @host_event_handler @with_connection_from_handle def on_connection_phy_update_failure(self, connection, error): - logger.debug(f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') + logger.debug( + f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}' + ) connection.emit('connection_phy_update_failure', error) @host_event_handler @with_connection_from_handle def on_connection_att_mtu_update(self, connection, att_mtu): - logger.debug(f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {att_mtu}') + logger.debug( + f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, {att_mtu}' + ) connection.att_mtu = att_mtu connection.emit('connection_att_mtu_update') @host_event_handler @with_connection_from_handle - def on_connection_data_length_change(self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time): - logger.debug(f'*** Connection Data Length Change: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}') - connection.data_length = (max_tx_octets, max_tx_time, max_rx_octets, max_rx_time) + def on_connection_data_length_change( + self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time + ): + logger.debug( + f'*** Connection Data Length Change: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}' + ) + connection.data_length = ( + max_tx_octets, + max_tx_time, + max_rx_octets, + max_rx_time, + ) connection.emit('connection_data_length_change') @with_connection_from_handle @@ -2337,12 +2676,16 @@ class Device(CompositeEventEmitter): # odd-numbered ones are server->client if att_pdu.op_code & 1: if connection.gatt_client is None: - logger.warn(color('no GATT client for connection 0x{connection_handle:04X}')) + logger.warn( + color('no GATT client for connection 0x{connection_handle:04X}') + ) return connection.gatt_client.on_gatt_pdu(att_pdu) else: if connection.gatt_server is None: - logger.warn(color('no GATT server for connection 0x{connection_handle:04X}')) + logger.warn( + color('no GATT server for connection 0x{connection_handle:04X}') + ) return connection.gatt_server.on_gatt_pdu(connection, att_pdu) diff --git a/bumble/gap.py b/bumble/gap.py index 8341215d..a4d5077d 100644 --- a/bumble/gap.py +++ b/bumble/gap.py @@ -23,7 +23,7 @@ from .gatt import ( Characteristic, GATT_GENERIC_ACCESS_SERVICE, GATT_DEVICE_NAME_CHARACTERISTIC, - GATT_APPEARANCE_CHARACTERISTIC + GATT_APPEARANCE_CHARACTERISTIC, ) # ----------------------------------------------------------------------------- @@ -38,22 +38,22 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- class GenericAccessService(Service): - def __init__(self, device_name, appearance = (0, 0)): + def __init__(self, device_name, appearance=(0, 0)): device_name_characteristic = Characteristic( GATT_DEVICE_NAME_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - device_name.encode('utf-8')[:248] + device_name.encode('utf-8')[:248], ) appearance_characteristic = Characteristic( GATT_APPEARANCE_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - struct.pack(' connection.att_mtu - 3: - value = value[:connection.att_mtu - 3] + value = value[: connection.att_mtu - 3] # Notify notification = ATT_Handle_Value_Notification( - attribute_handle = attribute.handle, - attribute_value = value + attribute_handle=attribute.handle, attribute_value=value + ) + logger.debug( + f'GATT Notify from server: [0x{connection.handle:04X}] {notification}' ) - logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}') self.send_gatt_pdu(connection.handle, bytes(notification)) async def indicate_subscriber(self, connection, attribute, value=None, force=False): @@ -273,46 +308,60 @@ class Server(EventEmitter): return cccd = subscribers.get(attribute.handle) if not cccd: - logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}') + logger.debug( + f'not indicating, no subscribers for handle {attribute.handle:04X}' + ) return if len(cccd) != 2 or (cccd[0] & 0x02 == 0): logger.debug(f'not indicating, cccd={cccd.hex()}') return # Get or encode the value - value = attribute.read_value(connection) if value is None else attribute.encode_value(value) + value = ( + attribute.read_value(connection) + if value is None + else attribute.encode_value(value) + ) # Truncate if needed if len(value) > connection.att_mtu - 3: - value = value[:connection.att_mtu - 3] + value = value[: connection.att_mtu - 3] # Indicate indication = ATT_Handle_Value_Indication( - attribute_handle = attribute.handle, - attribute_value = value + attribute_handle=attribute.handle, attribute_value=value + ) + logger.debug( + f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}' ) - logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}') # Wait until we can send (only one pending indication at a time per connection) async with self.indication_semaphores[connection.handle]: - assert(self.pending_confirmations[connection.handle] is None) + assert self.pending_confirmations[connection.handle] is None # Create a future value to hold the eventual response - self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future() + self.pending_confirmations[ + connection.handle + ] = asyncio.get_running_loop().create_future() try: self.send_gatt_pdu(connection.handle, indication.to_bytes()) - await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT) + await asyncio.wait_for( + self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT + ) except asyncio.TimeoutError: logger.warning(color('!!! GATT Indicate timeout', 'red')) raise TimeoutError(f'GATT timeout for {indication.name}') finally: self.pending_confirmations[connection.handle] = None - async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False): + async def notify_or_indicate_subscribers( + self, indicate, attribute, value=None, force=False + ): # Get all the connections for which there's at least one subscription connections = [ - connection for connection in [ + connection + for connection in [ self.device.lookup_connection(connection_handle) for (connection_handle, subscribers) in self.subscribers.items() if force or subscribers.get(attribute.handle) @@ -323,10 +372,12 @@ class Server(EventEmitter): # Indicate or notify for each connection if connections: coroutine = self.indicate_subscriber if indicate else self.notify_subscriber - await asyncio.wait([ - asyncio.create_task(coroutine(connection, attribute, value, force)) - for connection in connections - ]) + await asyncio.wait( + [ + asyncio.create_task(coroutine(connection, attribute, value, force)) + for connection in connections + ] + ) async def notify_subscribers(self, attribute, value=None, force=False): return await self.notify_or_indicate_subscribers(False, attribute, value, force) @@ -352,17 +403,17 @@ class Server(EventEmitter): except ATT_Error as error: logger.debug(f'normal exception returned by handler: {error}') response = ATT_Error_Response( - request_opcode_in_error = att_pdu.op_code, - attribute_handle_in_error = error.att_handle, - error_code = error.error_code + request_opcode_in_error=att_pdu.op_code, + attribute_handle_in_error=error.att_handle, + error_code=error.error_code, ) self.send_response(connection, response) except Exception as error: logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') response = ATT_Error_Response( - request_opcode_in_error = att_pdu.op_code, - attribute_handle_in_error = 0x0000, - error_code = ATT_UNLIKELY_ERROR_ERROR + request_opcode_in_error=att_pdu.op_code, + attribute_handle_in_error=0x0000, + error_code=ATT_UNLIKELY_ERROR_ERROR, ) self.send_response(connection, response) raise error @@ -373,7 +424,9 @@ class Server(EventEmitter): self.on_att_request(connection, att_pdu) else: # Just ignore - logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}') + logger.warning( + f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}' + ) ####################################################### # ATT handlers @@ -382,11 +435,13 @@ class Server(EventEmitter): ''' Handler for requests without a more specific handler ''' - logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}') + logger.warning( + f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}' + ) response = ATT_Error_Response( - request_opcode_in_error = pdu.op_code, - attribute_handle_in_error = 0x0000, - error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR + request_opcode_in_error=pdu.op_code, + attribute_handle_in_error=0x0000, + error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR, ) self.send_response(connection, response) @@ -394,7 +449,9 @@ class Server(EventEmitter): ''' See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request ''' - self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu)) + self.send_response( + connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) + ) # Compute the final MTU if request.client_rx_mtu >= ATT_DEFAULT_MTU: @@ -411,12 +468,18 @@ class Server(EventEmitter): ''' # Check the request parameters - if request.starting_handle == 0 or request.starting_handle > request.ending_handle: - self.send_response(connection, ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.starting_handle, - error_code = ATT_INVALID_HANDLE_ERROR - )) + if ( + request.starting_handle == 0 + or request.starting_handle > request.ending_handle + ): + self.send_response( + connection, + ATT_Error_Response( + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.starting_handle, + error_code=ATT_INVALID_HANDLE_ERROR, + ), + ) return # Build list of returned attributes @@ -424,9 +487,10 @@ class Server(EventEmitter): attributes = [] uuid_size = 0 for attribute in ( - attribute for attribute in self.attributes if - attribute.handle >= request.starting_handle and - attribute.handle <= request.ending_handle + attribute + for attribute in self.attributes + if attribute.handle >= request.starting_handle + and attribute.handle <= request.ending_handle ): # TODO: check permissions @@ -453,14 +517,14 @@ class Server(EventEmitter): for attribute in attributes ] response = ATT_Find_Information_Response( - format = 1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2, - information_data = b''.join(information_data_list) + format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2, + information_data=b''.join(information_data_list), ) else: response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.starting_handle, - error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.starting_handle, + error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) self.send_response(connection, response) @@ -474,12 +538,13 @@ class Server(EventEmitter): pdu_space_available = connection.att_mtu - 2 attributes = [] for attribute in ( - attribute for attribute in self.attributes if - attribute.handle >= request.starting_handle and - attribute.handle <= request.ending_handle and - attribute.type == request.attribute_type and - attribute.read_value(connection) == request.attribute_value and - pdu_space_available >= 4 + attribute + for attribute in self.attributes + if attribute.handle >= request.starting_handle + and attribute.handle <= request.ending_handle + and attribute.type == request.attribute_type + and attribute.read_value(connection) == request.attribute_value + and pdu_space_available >= 4 ): # TODO: check permissions @@ -494,22 +559,24 @@ class Server(EventEmitter): if attribute.type in { GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, - GATT_CHARACTERISTIC_ATTRIBUTE_TYPE + GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, }: # Part of a group group_end_handle = attribute.end_group_handle else: # Not part of a group group_end_handle = attribute.handle - handles_information_list.append(struct.pack('= request.starting_handle and - attribute.handle <= request.ending_handle and - pdu_space_available + attribute + for attribute in self.attributes + if attribute.type == request.attribute_type + and attribute.handle >= request.starting_handle + and attribute.handle <= request.ending_handle + and pdu_space_available ): # TODO: check permissions @@ -550,16 +618,17 @@ class Server(EventEmitter): pdu_space_available -= entry_size if attributes: - attribute_data_list = [struct.pack(' len(value): response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.attribute_handle, - error_code = ATT_INVALID_OFFSET_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.attribute_handle, + error_code=ATT_INVALID_OFFSET_ERROR, ) elif len(value) <= connection.att_mtu - 1: response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.attribute_handle, - error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.attribute_handle, + error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR, ) else: - part_size = min(connection.att_mtu - 1, len(value) - request.value_offset) + part_size = min( + connection.att_mtu - 1, len(value) - request.value_offset + ) response = ATT_Read_Blob_Response( - part_attribute_value = value[request.value_offset:request.value_offset + part_size] + part_attribute_value=value[ + request.value_offset : request.value_offset + part_size + ] ) else: response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.attribute_handle, - error_code = ATT_INVALID_HANDLE_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.attribute_handle, + error_code=ATT_INVALID_HANDLE_ERROR, ) self.send_response(connection, response) @@ -624,12 +695,12 @@ class Server(EventEmitter): if request.attribute_group_type not in { GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, - GATT_INCLUDE_ATTRIBUTE_TYPE + GATT_INCLUDE_ATTRIBUTE_TYPE, }: response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.starting_handle, - error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.starting_handle, + error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ) self.send_response(connection, response) return @@ -637,11 +708,12 @@ class Server(EventEmitter): pdu_space_available = connection.att_mtu - 2 attributes = [] for attribute in ( - attribute for attribute in self.attributes if - attribute.type == request.attribute_group_type and - attribute.handle >= request.starting_handle and - attribute.handle <= request.ending_handle and - pdu_space_available + attribute + for attribute in self.attributes + if attribute.type == request.attribute_group_type + and attribute.handle >= request.starting_handle + and attribute.handle <= request.ending_handle + and pdu_space_available ): # Check the attribute value size attribute_value = attribute.read_value(connection) @@ -659,7 +731,9 @@ class Server(EventEmitter): break # Add the attribute to the list - attributes.append((attribute.handle, attribute.end_group_handle, attribute_value)) + attributes.append( + (attribute.handle, attribute.end_group_handle, attribute_value) + ) pdu_space_available -= entry_size if attributes: @@ -668,14 +742,14 @@ class Server(EventEmitter): for handle, end_group_handle, value in attributes ] response = ATT_Read_By_Group_Type_Response( - length = len(attribute_data_list[0]), - attribute_data_list = b''.join(attribute_data_list) + length=len(attribute_data_list[0]), + attribute_data_list=b''.join(attribute_data_list), ) else: response = ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.starting_handle, - error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.starting_handle, + error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) self.send_response(connection, response) @@ -688,22 +762,28 @@ class Server(EventEmitter): # Check that the attribute exists attribute = self.get_attribute(request.attribute_handle) if attribute is None: - self.send_response(connection, ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.attribute_handle, - error_code = ATT_INVALID_HANDLE_ERROR - )) + self.send_response( + connection, + ATT_Error_Response( + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.attribute_handle, + error_code=ATT_INVALID_HANDLE_ERROR, + ), + ) return # TODO: check permissions # Check the request parameters if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE: - self.send_response(connection, ATT_Error_Response( - request_opcode_in_error = request.op_code, - attribute_handle_in_error = request.attribute_handle, - error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR - )) + self.send_response( + connection, + ATT_Error_Response( + request_opcode_in_error=request.op_code, + attribute_handle_in_error=request.attribute_handle, + error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR, + ), + ) return # Accept the value @@ -740,7 +820,9 @@ class Server(EventEmitter): ''' if self.pending_confirmations[connection.handle] is None: # Not expected! - logger.warning('!!! unexpected confirmation, there is no pending indication') + logger.warning( + '!!! unexpected confirmation, there is no pending indication' + ) return self.pending_confirmations[connection.handle].set_result(None) diff --git a/bumble/hci.py b/bumble/hci.py index 4805a6bd..b9946e2c 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) # Utils # ----------------------------------------------------------------------------- def hci_command_op_code(ogf, ocf): - return (ogf << 10 | ocf) + return ogf << 10 | ocf def key_with_value(dictionary, target_value): @@ -58,7 +58,11 @@ def map_null_terminated_utf8_string(utf8_bytes): def map_class_of_device(class_of_device): - service_classes, major_device_class, minor_device_class = DeviceClass.split_class_of_device(class_of_device) + ( + service_classes, + major_device_class, + minor_device_class, + ) = DeviceClass.split_class_of_device(class_of_device) return f'[{class_of_device:06X}] Services({",".join(DeviceClass.service_class_labels(service_classes))}),Class({DeviceClass.major_device_class_name(major_device_class)}|{DeviceClass.minor_device_class_name(major_device_class, minor_device_class)})' @@ -70,13 +74,14 @@ def phy_list_to_bits(phys): for phy in phys: if phy not in HCI_LE_PHY_TYPE_TO_BIT: raise ValueError('invalid PHY') - phy_bits |= (1 << HCI_LE_PHY_TYPE_TO_BIT[phy]) + phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy] return phy_bits # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off # HCI Version HCI_VERSION_BLUETOOTH_CORE_1_0B = 0 @@ -1349,6 +1354,8 @@ HCI_LE_SUPPORTED_FEATURES_NAMES = { if feature_name.startswith('HCI_') and feature_name.endswith('_LE_SUPPORTED_FEATURE') } +# fmt: on + # ----------------------------------------------------------------------------- STATUS_SPEC = {'size': 1, 'mapper': lambda x: HCI_Constant.status_name(x)} @@ -1382,8 +1389,7 @@ class HCI_Constant: @staticmethod def authentication_requirements_name(authentication_requirements): return HCI_AUTHENTICATION_REQUIREMENTS_NAMES.get( - authentication_requirements, - f'0x{authentication_requirements:02X}' + authentication_requirements, f'0x{authentication_requirements:02X}' ) @staticmethod @@ -1403,7 +1409,7 @@ class HCI_StatusError(ProtocolError): super().__init__( response.status, error_namespace=HCI_Command.command_name(response.command_opcode), - error_name=HCI_Constant.status_name(response.status) + error_name=HCI_Constant.status_name(response.status), ) @@ -1463,7 +1469,7 @@ class HCI_Object: offset += 1 elif field_type == 3: # 24-bit unsigned - padded = data[offset:offset + 3] + bytes([0]) + padded = data[offset : offset + 3] + bytes([0]) field_value = struct.unpack(' 4 and field_type <= 256: # Byte array (from 5 up to 256 bytes) - field_value = data[offset:offset + field_type] + field_value = data[offset : offset + field_type] offset += field_type elif callable(field_type): offset, field_value = field_type(data, offset) @@ -1535,7 +1541,11 @@ class HCI_Object: raise ValueError('value too large for *-typed field') else: field_bytes = bytes(field_value) - elif type(field_value) is bytes or type(field_value) is bytearray or hasattr(field_value, 'to_bytes'): + elif ( + type(field_value) is bytes + or type(field_value) is bytearray + or hasattr(field_value, 'to_bytes') + ): field_bytes = bytes(field_value) if type(field_type) is int and field_type > 4 and field_type <= 256: # Truncate or Pad with zeros if the field is too long or too short @@ -1544,7 +1554,9 @@ class HCI_Object: elif len(field_bytes) > field_type: field_bytes = field_bytes[:field_type] else: - raise ValueError(f"don't know how to serialize type {type(field_value)}") + raise ValueError( + f"don't know how to serialize type {type(field_value)}" + ) result += field_bytes @@ -1560,12 +1572,14 @@ class HCI_Object: @staticmethod def parse_length_prefixed_bytes(data, offset): length = data[offset] - return offset + 1 + length, data[offset + 1:offset + 1 + length] + return offset + 1 + length, data[offset + 1 : offset + 1 + length] @staticmethod def serialize_length_prefixed_bytes(data, padded_size=0): prefixed_size = 1 + len(data) - padding = bytes(padded_size - prefixed_size) if prefixed_size < padded_size else b'' + padding = ( + bytes(padded_size - prefixed_size) if prefixed_size < padded_size else b'' + ) return bytes([len(data)]) + data + padding @staticmethod @@ -1583,7 +1597,9 @@ class HCI_Object: return '' # Measure the widest field name - max_field_name_length = max([len(key[0] if type(key) is tuple else key) for key in keys]) + max_field_name_length = max( + [len(key[0] if type(key) is tuple else key) for key in keys] + ) # Build array of formatted key:value pairs fields = [] @@ -1606,7 +1622,9 @@ class HCI_Object: value = value_mapper(value) # Get the string representation of the value - value_str = HCI_Object.format_field_value(value, indentation = indentation + ' ') + value_str = HCI_Object.format_field_value( + value, indentation=indentation + ' ' + ) # Add the field to the formatted result key_str = color(f'{key + ":":{1 + max_field_name_length}}', 'cyan') @@ -1622,7 +1640,9 @@ class HCI_Object: self.init_from_fields(self, fields, kwargs) def to_string(self, indentation='', value_mappers={}): - return HCI_Object.format_fields(self.__dict__, self.fields, indentation, value_mappers) + return HCI_Object.format_fields( + self.__dict__, self.fields, indentation, value_mappers + ) def __str__(self): return self.to_string() @@ -1638,16 +1658,16 @@ class Address: address[0] is the LSB of the address, address[5] is the MSB. ''' - PUBLIC_DEVICE_ADDRESS = 0x00 - RANDOM_DEVICE_ADDRESS = 0x01 + PUBLIC_DEVICE_ADDRESS = 0x00 + RANDOM_DEVICE_ADDRESS = 0x01 PUBLIC_IDENTITY_ADDRESS = 0x02 RANDOM_IDENTITY_ADDRESS = 0x03 ADDRESS_TYPE_NAMES = { - PUBLIC_DEVICE_ADDRESS: 'PUBLIC_DEVICE_ADDRESS', - RANDOM_DEVICE_ADDRESS: 'RANDOM_DEVICE_ADDRESS', + PUBLIC_DEVICE_ADDRESS: 'PUBLIC_DEVICE_ADDRESS', + RANDOM_DEVICE_ADDRESS: 'RANDOM_DEVICE_ADDRESS', PUBLIC_IDENTITY_ADDRESS: 'PUBLIC_IDENTITY_ADDRESS', - RANDOM_IDENTITY_ADDRESS: 'RANDOM_IDENTITY_ADDRESS' + RANDOM_IDENTITY_ADDRESS: 'RANDOM_IDENTITY_ADDRESS', } ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)} @@ -1667,18 +1687,20 @@ class Address: @staticmethod def parse_address(data, offset): # Fix the type to a default value. This is used for parsing type-less Classic addresses - return Address.parse_address_with_type(data, offset, Address.PUBLIC_DEVICE_ADDRESS) + return Address.parse_address_with_type( + data, offset, Address.PUBLIC_DEVICE_ADDRESS + ) @staticmethod def parse_address_with_type(data, offset, address_type): - return offset + 6, Address(data[offset:offset + 6], address_type) + return offset + 6, Address(data[offset : offset + 6], address_type) @staticmethod def parse_address_preceded_by_type(data, offset): address_type = data[offset - 1] return Address.parse_address_with_type(data, offset, address_type) - def __init__(self, address, address_type = RANDOM_DEVICE_ADDRESS): + def __init__(self, address, address_type=RANDOM_DEVICE_ADDRESS): ''' Initialize an instance. `address` may be a byte array in little-endian format, or a hex string in big-endian format (with optional ':' @@ -1709,7 +1731,10 @@ class Address: @property def is_public(self): - return self.address_type == self.PUBLIC_DEVICE_ADDRESS or self.address_type == self.PUBLIC_IDENTITY_ADDRESS + return ( + self.address_type == self.PUBLIC_DEVICE_ADDRESS + or self.address_type == self.PUBLIC_IDENTITY_ADDRESS + ) @property def is_random(self): @@ -1717,11 +1742,16 @@ class Address: @property def is_resolved(self): - return self.address_type == self.PUBLIC_IDENTITY_ADDRESS or self.address_type == self.RANDOM_IDENTITY_ADDRESS + return ( + self.address_type == self.PUBLIC_IDENTITY_ADDRESS + or self.address_type == self.RANDOM_IDENTITY_ADDRESS + ) @property def is_resolvable(self): - return self.address_type == self.RANDOM_DEVICE_ADDRESS and (self.address_bytes[5] >> 6 == 1) + return self.address_type == self.RANDOM_DEVICE_ADDRESS and ( + self.address_bytes[5] >> 6 == 1 + ) @property def is_static(self): @@ -1737,7 +1767,10 @@ class Address: return hash(self.address_bytes) def __eq__(self, other): - return self.address_bytes == other.address_bytes and self.is_public == other.is_public + return ( + self.address_bytes == other.address_bytes + and self.is_public == other.is_public + ) def __str__(self): ''' @@ -1761,10 +1794,10 @@ class OwnAddressType: RESOLVABLE_OR_RANDOM = 3 TYPE_NAMES = { - PUBLIC: 'PUBLIC', - RANDOM: 'RANDOM', + PUBLIC: 'PUBLIC', + RANDOM: 'RANDOM', RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC', - RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM' + RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM', } @staticmethod @@ -1773,6 +1806,7 @@ class OwnAddressType: TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)} + # ----------------------------------------------------------------------------- class HCI_Packet: ''' @@ -1803,7 +1837,7 @@ class HCI_CustomPacket(HCI_Packet): def __init__(self, payload): super().__init__('HCI_CUSTOM_PACKET') self.hci_packet_type = payload[0] - self.payload = payload + self.payload = payload # ----------------------------------------------------------------------------- @@ -1811,6 +1845,7 @@ class HCI_Command(HCI_Packet): ''' See Bluetooth spec @ Vol 2, Part E - 5.4.1 HCI Command Packet ''' + hci_packet_type = HCI_COMMAND_PACKET command_classes = {} @@ -1830,8 +1865,10 @@ class HCI_Command(HCI_Packet): # Patch the __init__ method to fix the op_code if fields is not None: + def init(self, parameters=None, **kwargs): return HCI_Command.__init__(self, cls.op_code, parameters, **kwargs) + cls.__init__ = init # Register a factory for this class @@ -1880,12 +1917,15 @@ class HCI_Command(HCI_Packet): HCI_Object.init_from_fields(self, fields, kwargs) if parameters is None: parameters = HCI_Object.dict_to_bytes(kwargs, fields) - self.op_code = op_code + self.op_code = op_code self.parameters = parameters def to_bytes(self): parameters = b'' if self.parameters is None else self.parameters - return struct.pack('> 5) & 3]) + event_type_flags.append( + ('COMPLETE', 'INCOMPLETE+', 'INCOMPLETE#', '?')[(event_type >> 5) & 3] + ) - if event_type & (1 << HCI_LE_Extended_Advertising_Report_Event.LEGACY_ADVERTISING_PDU_USED): - legacy_pdu_type = HCI_LE_Extended_Advertising_Report_Event.LEGACY_PDU_TYPE_MAP.get(event_type & 0x0F) + if event_type & ( + 1 << HCI_LE_Extended_Advertising_Report_Event.LEGACY_ADVERTISING_PDU_USED + ): + legacy_pdu_type = ( + HCI_LE_Extended_Advertising_Report_Event.LEGACY_PDU_TYPE_MAP.get( + event_type & 0x0F + ) + ) if legacy_pdu_type is not None: legacy_info_string = f'({HCI_LE_Advertising_Report_Event.event_type_name(legacy_pdu_type)})' else: @@ -4232,23 +4496,26 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): self.reports = reports[:] # Serialize the fields - parameters = bytes([HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT, len(reports)]) + b''.join([bytes(report) for report in reports]) + parameters = bytes( + [HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT, len(reports)] + ) + b''.join([bytes(report) for report in reports]) super().__init__(self.subevent_code, parameters) def __str__(self): - reports = '\n'.join([f'{i}:\n{report.to_string(" ")}' for i, report in enumerate(self.reports)]) + reports = '\n'.join( + [f'{i}:\n{report.to_string(" ")}' for i, report in enumerate(self.reports)] + ) return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}' -HCI_Event.meta_event_classes[HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT] = HCI_LE_Extended_Advertising_Report_Event +HCI_Event.meta_event_classes[ + HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT +] = HCI_LE_Extended_Advertising_Report_Event # ----------------------------------------------------------------------------- -@HCI_LE_Meta_Event.event([ - ('connection_handle', 2), - ('channel_selection_algorithm', 1) -]) +@HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)]) class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event): ''' See Bluetooth spec @ 7.7.65.20 LE Channel Selection Algorithm Event @@ -4256,9 +4523,7 @@ class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC) -]) +@HCI_Event.event([('status', STATUS_SPEC)]) class HCI_Inquiry_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.1 Inquiry Complete Event @@ -4273,12 +4538,12 @@ class HCI_Inquiry_Result_Event(HCI_Event): ''' RESPONSE_FIELDS = [ - ('bd_addr', Address.parse_address), + ('bd_addr', Address.parse_address), ('page_scan_repetition_mode', 1), - ('reserved', 1), - ('reserved', 1), - ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), - ('clock_offset', 2) + ('reserved', 1), + ('reserved', 1), + ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), + ('clock_offset', 2), ] @staticmethod @@ -4287,7 +4552,9 @@ class HCI_Inquiry_Result_Event(HCI_Event): responses = [] offset = 1 for _ in range(num_responses): - response = HCI_Object.from_bytes(parameters, offset, HCI_Inquiry_Result_Event.RESPONSE_FIELDS) + response = HCI_Object.from_bytes( + parameters, offset, HCI_Inquiry_Result_Event.RESPONSE_FIELDS + ) offset += 14 responses.append(response) @@ -4297,36 +4564,48 @@ class HCI_Inquiry_Result_Event(HCI_Event): self.responses = responses[:] # Serialize the fields - parameters = bytes([HCI_INQUIRY_RESULT_EVENT, len(responses)]) + b''.join([bytes(response) for response in responses]) + parameters = bytes([HCI_INQUIRY_RESULT_EVENT, len(responses)]) + b''.join( + [bytes(response) for response in responses] + ) super().__init__(HCI_INQUIRY_RESULT_EVENT, parameters) def __str__(self): - responses = '\n'.join([response.to_string(indentation=' ') for response in self.responses]) + responses = '\n'.join( + [response.to_string(indentation=' ') for response in self.responses] + ) return f'{color("HCI_INQUIRY_RESULT_EVENT", "magenta")}:\n{responses}' # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('bd_addr', Address.parse_address), - ('link_type', {'size': 1, 'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x)}), - ('encryption_enabled', 1) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('bd_addr', Address.parse_address), + ( + 'link_type', + { + 'size': 1, + 'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x), + }, + ), + ('encryption_enabled', 1), + ] +) class HCI_Connection_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.3 Connection Complete Event ''' - SCO_LINK_TYPE = 0x00 - ACL_LINK_TYPE = 0x01 + SCO_LINK_TYPE = 0x00 + ACL_LINK_TYPE = 0x01 ESCO_LINK_TYPE = 0x02 LINK_TYPE_NAMES = { - SCO_LINK_TYPE: 'SCO', - ACL_LINK_TYPE: 'ACL', - ESCO_LINK_TYPE: 'eSCO' + SCO_LINK_TYPE: 'SCO', + ACL_LINK_TYPE: 'ACL', + ESCO_LINK_TYPE: 'eSCO', } @staticmethod @@ -4335,11 +4614,19 @@ class HCI_Connection_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('class_of_device', 3), - ('link_type', {'size': 1, 'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x)}) -]) +@HCI_Event.event( + [ + ('bd_addr', Address.parse_address), + ('class_of_device', 3), + ( + 'link_type', + { + 'size': 1, + 'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x), + }, + ), + ] +) class HCI_Connection_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.4 Connection Request Event @@ -4347,11 +4634,13 @@ class HCI_Connection_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('reason', {'size': 1, 'mapper': HCI_Constant.error_name}) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('reason', {'size': 1, 'mapper': HCI_Constant.error_name}), + ] +) class HCI_Disconnection_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.5 Disconnection Complete Event @@ -4359,10 +4648,7 @@ class HCI_Disconnection_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2) -]) +@HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)]) class HCI_Authentication_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.6 Authentication Complete Event @@ -4370,11 +4656,13 @@ class HCI_Authentication_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('bd_addr', Address.parse_address), - ('remote_name', {'size': 248, 'mapper': map_null_terminated_utf8_string}) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('bd_addr', Address.parse_address), + ('remote_name', {'size': 248, 'mapper': map_null_terminated_utf8_string}), + ] +) class HCI_Remote_Name_Request_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.7 Remote Name Request Complete Event @@ -4382,37 +4670,47 @@ class HCI_Remote_Name_Request_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('encryption_enabled', {'size': 1, 'mapper': lambda x: HCI_Encryption_Change_Event.encryption_enabled_name(x)}) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ( + 'encryption_enabled', + { + 'size': 1, + 'mapper': lambda x: HCI_Encryption_Change_Event.encryption_enabled_name( + x + ), + }, + ), + ] +) class HCI_Encryption_Change_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.8 Encryption Change Event ''' - OFF = 0x00 + OFF = 0x00 E0_OR_AES_CCM = 0x01 - AES_CCM = 0x02 + AES_CCM = 0x02 ENCRYPTION_ENABLED_NAMES = { - OFF: 'OFF', + OFF: 'OFF', E0_OR_AES_CCM: 'E0_OR_AES_CCM', - AES_CCM: 'AES_CCM' + AES_CCM: 'AES_CCM', } @staticmethod def encryption_enabled_name(encryption_enabled): - return name_or_number(HCI_Encryption_Change_Event.ENCRYPTION_ENABLED_NAMES, encryption_enabled) + return name_or_number( + HCI_Encryption_Change_Event.ENCRYPTION_ENABLED_NAMES, encryption_enabled + ) # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('lmp_features', 8) -]) +@HCI_Event.event( + [('status', STATUS_SPEC), ('connection_handle', 2), ('lmp_features', 8)] +) class HCI_Read_Remote_Supported_Features_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.11 Read Remote Supported Features Complete Event @@ -4420,13 +4718,15 @@ class HCI_Read_Remote_Supported_Features_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('version', 1), - ('manufacturer_name', 2), - ('subversion', 2) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('version', 1), + ('manufacturer_name', 2), + ('subversion', 2), + ] +) class HCI_Read_Remote_Version_Information_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.12 Read Remote Version Information Complete Event @@ -4434,11 +4734,13 @@ class HCI_Read_Remote_Version_Information_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('num_hci_command_packets', 1), - ('command_opcode', {'size': 2, 'mapper': HCI_Command.command_name}), - ('return_parameters', '*') -]) +@HCI_Event.event( + [ + ('num_hci_command_packets', 1), + ('command_opcode', {'size': 2, 'mapper': HCI_Command.command_name}), + ('return_parameters', '*'), + ] +) class HCI_Command_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.14 Command Complete Event @@ -4459,7 +4761,9 @@ class HCI_Command_Complete_Event(HCI_Event): def from_parameters(parameters): self = HCI_Command_Complete_Event.__new__(HCI_Command_Complete_Event) HCI_Event.__init__(self, self.event_code, parameters) - HCI_Object.init_from_bytes(self, parameters, 0, HCI_Command_Complete_Event.fields) + HCI_Object.init_from_bytes( + self, parameters, 0, HCI_Command_Complete_Event.fields + ) # Parse the return parameters if type(self.return_parameters) is bytes and len(self.return_parameters) == 1: @@ -4468,27 +4772,38 @@ class HCI_Command_Complete_Event(HCI_Event): else: cls = HCI_Command.command_classes.get(self.command_opcode) if cls and cls.return_parameters_fields: - self.return_parameters = HCI_Object.from_bytes(self.return_parameters, 0, cls.return_parameters_fields) + self.return_parameters = HCI_Object.from_bytes( + self.return_parameters, 0, cls.return_parameters_fields + ) self.return_parameters.fields = cls.return_parameters_fields return self def __str__(self): - return f'{color(self.name, "magenta")}:\n' + HCI_Object.format_fields(self.__dict__, self.fields, ' ', { - 'return_parameters': self.map_return_parameters - }) + return f'{color(self.name, "magenta")}:\n' + HCI_Object.format_fields( + self.__dict__, + self.fields, + ' ', + {'return_parameters': self.map_return_parameters}, + ) # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', {'size': 1, 'mapper': lambda x: HCI_Command_Status_Event.status_name(x)}), - ('num_hci_command_packets', 1), - ('command_opcode', {'size': 2, 'mapper': HCI_Command.command_name}) -]) +@HCI_Event.event( + [ + ( + 'status', + {'size': 1, 'mapper': lambda x: HCI_Command_Status_Event.status_name(x)}, + ), + ('num_hci_command_packets', 1), + ('command_opcode', {'size': 2, 'mapper': HCI_Command.command_name}), + ] +) class HCI_Command_Status_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.15 Command Complete Event ''' + PENDING = 0 @staticmethod @@ -4500,11 +4815,13 @@ class HCI_Command_Status_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('bd_addr', Address.parse_address), - ('new_role', {'size': 1, 'mapper': HCI_Constant.role_name}) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('bd_addr', Address.parse_address), + ('new_role', {'size': 1, 'mapper': HCI_Constant.role_name}), + ] +) class HCI_Role_Change_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.18 Role Change Event @@ -4549,34 +4866,46 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event): def __str__(self): lines = [ color(self.name, 'magenta') + ':', - color(' number_of_handles: ', 'cyan') + f'{len(self.connection_handles)}' + color(' number_of_handles: ', 'cyan') + + f'{len(self.connection_handles)}', ] for i in range(len(self.connection_handles)): - lines.append(color(f' connection_handle[{i}]: ', 'cyan') + f'{self.connection_handles[i]}') - lines.append(color(f' num_completed_packets[{i}]: ', 'cyan') + f'{self.num_completed_packets[i]}') + lines.append( + color(f' connection_handle[{i}]: ', 'cyan') + + f'{self.connection_handles[i]}' + ) + lines.append( + color(f' num_completed_packets[{i}]: ', 'cyan') + + f'{self.num_completed_packets[i]}' + ) return '\n'.join(lines) # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('current_mode', {'size': 1, 'mapper': lambda x: HCI_Mode_Change_Event.mode_name(x)}), - ('interval', 2) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ( + 'current_mode', + {'size': 1, 'mapper': lambda x: HCI_Mode_Change_Event.mode_name(x)}, + ), + ('interval', 2), + ] +) class HCI_Mode_Change_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.20 Mode Change Event ''' ACTIVE_MODE = 0x00 - HOLD_MODE = 0x01 - SNIFF_MODE = 0x02 + HOLD_MODE = 0x01 + SNIFF_MODE = 0x02 MODE_NAMES = { ACTIVE_MODE: 'ACTIVE_MODE', - HOLD_MODE: 'HOLD_MODE', - SNIFF_MODE: 'SNIFF_MODE' + HOLD_MODE: 'HOLD_MODE', + SNIFF_MODE: 'SNIFF_MODE', } @staticmethod @@ -4585,9 +4914,7 @@ class HCI_Mode_Change_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address) -]) +@HCI_Event.event([('bd_addr', Address.parse_address)]) class HCI_PIN_Code_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.22 PIN Code Request Event @@ -4595,9 +4922,7 @@ class HCI_PIN_Code_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address) -]) +@HCI_Event.event([('bd_addr', Address.parse_address)]) class HCI_Link_Key_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.24 7.7.23 Link Key Request Event @@ -4605,11 +4930,13 @@ class HCI_Link_Key_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('link_key', 16), - ('key_type', {'size': 1, 'mapper': HCI_Constant.link_key_type_name}) -]) +@HCI_Event.event( + [ + ('bd_addr', Address.parse_address), + ('link_key', 16), + ('key_type', {'size': 1, 'mapper': HCI_Constant.link_key_type_name}), + ] +) class HCI_Link_Key_Notification_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.24 Link Key Notification Event @@ -4617,10 +4944,7 @@ class HCI_Link_Key_Notification_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('connection_handle', 2), - ('lmp_max_slots', 1) -]) +@HCI_Event.event([('connection_handle', 2), ('lmp_max_slots', 1)]) class HCI_Max_Slots_Change_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.27 Max Slots Change Event @@ -4628,11 +4952,9 @@ class HCI_Max_Slots_Change_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('clock_offset', 2) -]) +@HCI_Event.event( + [('status', STATUS_SPEC), ('connection_handle', 2), ('clock_offset', 2)] +) class HCI_Read_Clock_Offset_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.28 Read Clock Offset Complete Event @@ -4640,11 +4962,9 @@ class HCI_Read_Clock_Offset_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('packet_type', 2) -]) +@HCI_Event.event( + [('status', STATUS_SPEC), ('connection_handle', 2), ('packet_type', 2)] +) class HCI_Connection_Packet_Type_Changed_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.29 Connection Packet Type Changed Event @@ -4652,10 +4972,7 @@ class HCI_Connection_Packet_Type_Changed_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('page_scan_repetition_mode', 1) -]) +@HCI_Event.event([('bd_addr', Address.parse_address), ('page_scan_repetition_mode', 1)]) class HCI_Page_Scan_Repetition_Mode_Change_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.31 Page Scan Repetition Mode Change Event @@ -4670,12 +4987,12 @@ class HCI_Inquiry_Result_With_RSSI_Event(HCI_Event): ''' RESPONSE_FIELDS = [ - ('bd_addr', Address.parse_address), + ('bd_addr', Address.parse_address), ('page_scan_repetition_mode', 1), - ('reserved', 1), - ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), - ('clock_offset', 2), - ('rssi', -1) + ('reserved', 1), + ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), + ('clock_offset', 2), + ('rssi', -1), ] @staticmethod @@ -4684,7 +5001,9 @@ class HCI_Inquiry_Result_With_RSSI_Event(HCI_Event): responses = [] offset = 1 for _ in range(num_responses): - response = HCI_Object.from_bytes(parameters, offset, HCI_Inquiry_Result_With_RSSI_Event.RESPONSE_FIELDS) + response = HCI_Object.from_bytes( + parameters, offset, HCI_Inquiry_Result_With_RSSI_Event.RESPONSE_FIELDS + ) offset += 14 responses.append(response) @@ -4694,23 +5013,29 @@ class HCI_Inquiry_Result_With_RSSI_Event(HCI_Event): self.responses = responses[:] # Serialize the fields - parameters = bytes([HCI_INQUIRY_RESULT_WITH_RSSI_EVENT, len(responses)]) + b''.join([bytes(response) for response in responses]) + parameters = bytes( + [HCI_INQUIRY_RESULT_WITH_RSSI_EVENT, len(responses)] + ) + b''.join([bytes(response) for response in responses]) super().__init__(HCI_INQUIRY_RESULT_WITH_RSSI_EVENT, parameters) def __str__(self): - responses = '\n'.join([response.to_string(indentation=' ') for response in self.responses]) + responses = '\n'.join( + [response.to_string(indentation=' ') for response in self.responses] + ) return f'{color("HCI_INQUIRY_RESULT_WITH_RSSI_EVENT", "magenta")}:\n{responses}' # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('page_number', 1), - ('maximum_page_number', 1), - ('extended_lmp_features', 8) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('page_number', 1), + ('maximum_page_number', 1), + ('extended_lmp_features', 8), + ] +) class HCI_Read_Remote_Extended_Features_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.34 Read Remote Extended Features Complete Event @@ -4718,17 +5043,35 @@ class HCI_Read_Remote_Extended_Features_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('bd_addr', Address.parse_address), - ('link_type', {'size': 1, 'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.link_type_name(x)}), - ('transmission_interval', 1), - ('retransmission_window', 1), - ('rx_packet_length', 2), - ('tx_packet_length', 2), - ('air_mode', {'size': 1, 'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.air_mode_name(x)}), -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('bd_addr', Address.parse_address), + ( + 'link_type', + { + 'size': 1, + 'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.link_type_name( + x + ), + }, + ), + ('transmission_interval', 1), + ('retransmission_window', 1), + ('rx_packet_length', 2), + ('tx_packet_length', 2), + ( + 'air_mode', + { + 'size': 1, + 'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.air_mode_name( + x + ), + }, + ), + ] +) class HCI_Synchronous_Connection_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.35 Synchronous Connection Complete Event @@ -4738,40 +5081,46 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event): ESCO_CONNECTION_LINK_TYPE = 0x02 LINK_TYPE_NAMES = { - SCO_CONNECTION_LINK_TYPE: 'SCO', - ESCO_CONNECTION_LINK_TYPE: 'eSCO' + SCO_CONNECTION_LINK_TYPE: 'SCO', + ESCO_CONNECTION_LINK_TYPE: 'eSCO', } - U_LAW_LOG_AIR_MODE = 0x00 - A_LAW_LOG_AIR_MORE = 0x01 - CVSD_AIR_MODE = 0x02 + U_LAW_LOG_AIR_MODE = 0x00 + A_LAW_LOG_AIR_MORE = 0x01 + CVSD_AIR_MODE = 0x02 TRANSPARENT_DATA_AIR_MODE = 0x03 AIR_MODE_NAMES = { - U_LAW_LOG_AIR_MODE: 'u-law log', - A_LAW_LOG_AIR_MORE: 'A-law log', - CVSD_AIR_MODE: 'CVSD', - TRANSPARENT_DATA_AIR_MODE: 'Transparent Data' + U_LAW_LOG_AIR_MODE: 'u-law log', + A_LAW_LOG_AIR_MORE: 'A-law log', + CVSD_AIR_MODE: 'CVSD', + TRANSPARENT_DATA_AIR_MODE: 'Transparent Data', } @staticmethod def link_type_name(link_type): - return name_or_number(HCI_Synchronous_Connection_Complete_Event.LINK_TYPE_NAMES, link_type) + return name_or_number( + HCI_Synchronous_Connection_Complete_Event.LINK_TYPE_NAMES, link_type + ) @staticmethod def air_mode_name(air_mode): - return name_or_number(HCI_Synchronous_Connection_Complete_Event.AIR_MODE_NAMES, air_mode) + return name_or_number( + HCI_Synchronous_Connection_Complete_Event.AIR_MODE_NAMES, air_mode + ) # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2), - ('transmission_interval', 1), - ('retransmission_window', 1), - ('rx_packet_length', 2), - ('tx_packet_length', 2) -]) +@HCI_Event.event( + [ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ('transmission_interval', 1), + ('retransmission_window', 1), + ('rx_packet_length', 2), + ('tx_packet_length', 2), + ] +) class HCI_Synchronous_Connection_Changed_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.36 Synchronous Connection Changed Event @@ -4779,16 +5128,18 @@ class HCI_Synchronous_Connection_Changed_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('num_responses', 1), - ('bd_addr', Address.parse_address), - ('page_scan_repetition_mode', 1), - ('reserved', 1), - ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), - ('clock_offset', 2), - ('rssi', -1), - ('extended_inquiry_response', 240), -]) +@HCI_Event.event( + [ + ('num_responses', 1), + ('bd_addr', Address.parse_address), + ('page_scan_repetition_mode', 1), + ('reserved', 1), + ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), + ('clock_offset', 2), + ('rssi', -1), + ('extended_inquiry_response', 240), + ] +) class HCI_Extended_Inquiry_Result_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.38 Extended Inquiry Result Event @@ -4796,10 +5147,7 @@ class HCI_Extended_Inquiry_Result_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('connection_handle', 2) -]) +@HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)]) class HCI_Encryption_Key_Refresh_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.39 Encryption Key Refresh Complete Event @@ -4807,9 +5155,7 @@ class HCI_Encryption_Key_Refresh_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address) -]) +@HCI_Event.event([('bd_addr', Address.parse_address)]) class HCI_IO_Capability_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.40 IO Capability Request Event @@ -4817,12 +5163,17 @@ class HCI_IO_Capability_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('io_capability', {'size': 1, 'mapper': HCI_Constant.io_capability_name}), - ('oob_data_present', 1), - ('authentication_requirements', {'size': 1, 'mapper': HCI_Constant.authentication_requirements_name}) -]) +@HCI_Event.event( + [ + ('bd_addr', Address.parse_address), + ('io_capability', {'size': 1, 'mapper': HCI_Constant.io_capability_name}), + ('oob_data_present', 1), + ( + 'authentication_requirements', + {'size': 1, 'mapper': HCI_Constant.authentication_requirements_name}, + ), + ] +) class HCI_IO_Capability_Response_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.41 IO Capability Response Event @@ -4830,10 +5181,7 @@ class HCI_IO_Capability_Response_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('numeric_value', 4) -]) +@HCI_Event.event([('bd_addr', Address.parse_address), ('numeric_value', 4)]) class HCI_User_Confirmation_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.42 User Confirmation Request Event @@ -4841,9 +5189,7 @@ class HCI_User_Confirmation_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address) -]) +@HCI_Event.event([('bd_addr', Address.parse_address)]) class HCI_User_Passkey_Request_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.43 User Passkey Request Event @@ -4851,10 +5197,7 @@ class HCI_User_Passkey_Request_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('status', STATUS_SPEC), - ('bd_addr', Address.parse_address) -]) +@HCI_Event.event([('status', STATUS_SPEC), ('bd_addr', Address.parse_address)]) class HCI_Simple_Pairing_Complete_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.45 Simple Pairing Complete Event @@ -4862,10 +5205,7 @@ class HCI_Simple_Pairing_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('connection_handle', 2), - ('link_supervision_timeout', 2) -]) +@HCI_Event.event([('connection_handle', 2), ('link_supervision_timeout', 2)]) class HCI_Link_Supervision_Timeout_Changed_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.46 Link Supervision Timeout Changed Event @@ -4873,10 +5213,7 @@ class HCI_Link_Supervision_Timeout_Changed_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('passkey', 4) -]) +@HCI_Event.event([('bd_addr', Address.parse_address), ('passkey', 4)]) class HCI_User_Passkey_Notification_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.48 User Passkey Notification Event @@ -4884,10 +5221,7 @@ class HCI_User_Passkey_Notification_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.event([ - ('bd_addr', Address.parse_address), - ('host_supported_features', 8) -]) +@HCI_Event.event([('bd_addr', Address.parse_address), ('host_supported_features', 8)]) class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.50 Remote Host Supported Features Notification Event @@ -4899,6 +5233,7 @@ class HCI_AclDataPacket(HCI_Packet): ''' See Bluetooth spec @ 5.4.2 HCI ACL Data Packets ''' + hci_packet_type = HCI_ACL_DATA_PACKET @staticmethod @@ -4911,11 +5246,16 @@ class HCI_AclDataPacket(HCI_Packet): data = packet[5:] if len(data) != data_total_length: raise ValueError('invalid packet length') - return HCI_AclDataPacket(connection_handle, pb_flag, bc_flag, data_total_length, data) + return HCI_AclDataPacket( + connection_handle, pb_flag, bc_flag, data_total_length, data + ) def to_bytes(self): h = (self.pb_flag << 12) | (self.bc_flag << 14) | self.connection_handle - return struct.pack(' self.l2cap_pdu_length + 4: logger.warning('!!! ACL data exceeds L2CAP PDU') - self.current_data = None + self.current_data = None self.l2cap_pdu_length = 0 diff --git a/bumble/helpers.py b/bumble/helpers.py index e9ebf972..5393ee02 100644 --- a/bumble/helpers.py +++ b/bumble/helpers.py @@ -29,20 +29,17 @@ from .l2cap import ( L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID, L2CAP_Control_Frame, - L2CAP_Connection_Response + L2CAP_Connection_Response, ) from .hci import ( HCI_EVENT_PACKET, HCI_ACL_DATA_PACKET, HCI_DISCONNECTION_COMPLETE_EVENT, - HCI_AclDataPacketAssembler + HCI_AclDataPacketAssembler, ) from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from .sdp import SDP_PDU, SDP_PSM -from .avdtp import ( - MessageAssembler as AVDTP_MessageAssembler, - AVDTP_PSM -) +from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM # ----------------------------------------------------------------------------- # Logging @@ -53,8 +50,8 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- PSM_NAMES = { RFCOMM_PSM: 'RFCOMM', - SDP_PSM: 'SDP', - AVDTP_PSM: 'AVDTP' + SDP_PSM: 'SDP', + AVDTP_PSM: 'AVDTP' # TODO: add more PSM values } @@ -63,11 +60,11 @@ PSM_NAMES = { class PacketTracer: class AclStream: def __init__(self, analyzer): - self.analyzer = analyzer + self.analyzer = analyzer self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) - self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid - self.psms = {} # PSM, by source_cid - self.peer = None # ACL stream in the other direction + self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid + self.psms = {} # PSM, by source_cid + self.peer = None # ACL stream in the other direction def on_acl_pdu(self, pdu): l2cap_pdu = L2CAP_PDU.from_bytes(pdu) @@ -78,7 +75,10 @@ class PacketTracer: elif l2cap_pdu.cid == SMP_CID: smp_command = SMP_Command.from_bytes(l2cap_pdu.payload) self.analyzer.emit(smp_command) - elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID: + elif ( + l2cap_pdu.cid == L2CAP_SIGNALING_CID + or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID + ): control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload) self.analyzer.emit(control_frame) @@ -86,7 +86,10 @@ class PacketTracer: if control_frame.code == L2CAP_CONNECTION_REQUEST: self.psms[control_frame.source_cid] = control_frame.psm elif control_frame.code == L2CAP_CONNECTION_RESPONSE: - if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: + if ( + control_frame.result + == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL + ): if self.peer: if psm := self.peer.psms.get(control_frame.source_cid): # Found a pending connection @@ -94,8 +97,14 @@ class PacketTracer: # For AVDTP connections, create a packet assembler for each direction if psm == AVDTP_PSM: - self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message) - self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message) + self.avdtp_assemblers[ + control_frame.source_cid + ] = AVDTP_MessageAssembler(self.on_avdtp_message) + self.peer.avdtp_assemblers[ + control_frame.destination_cid + ] = AVDTP_MessageAssembler( + self.peer.on_avdtp_message + ) else: # Try to find the PSM associated with this PDU @@ -107,31 +116,39 @@ class PacketTracer: rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload) self.analyzer.emit(rfcomm_frame) elif psm == AVDTP_PSM: - self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}') + self.analyzer.emit( + f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}' + ) assembler = self.avdtp_assemblers.get(l2cap_pdu.cid) if assembler: assembler.on_pdu(l2cap_pdu.payload) else: psm_string = name_or_number(PSM_NAMES, psm) - self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}') + self.analyzer.emit( + f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}' + ) else: self.analyzer.emit(l2cap_pdu) def on_avdtp_message(self, transaction_label, message): - self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}') + self.analyzer.emit( + f'{color("AVDTP", "green")} [{transaction_label}] {message}' + ) def feed_packet(self, packet): self.packet_assembler.feed_packet(packet) class Analyzer: def __init__(self, label, emit_message): - self.label = label + self.label = label self.emit_message = emit_message - self.acl_streams = {} # ACL streams, by connection handle - self.peer = None # Analyzer in the other direction + self.acl_streams = {} # ACL streams, by connection handle + self.peer = None # Analyzer in the other direction def start_acl_stream(self, connection_handle): - logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}') + logger.info( + f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}' + ) stream = PacketTracer.AclStream(self) self.acl_streams[connection_handle] = stream @@ -144,7 +161,9 @@ class PacketTracer: def end_acl_stream(self, connection_handle): if connection_handle in self.acl_streams: - logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}') + logger.info( + f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}' + ) del self.acl_streams[connection_handle] # Let the other forwarder know so it can cleanup its stream as well @@ -176,9 +195,13 @@ class PacketTracer: self, host_to_controller_label=color('HOST->CONTROLLER', 'blue'), controller_to_host_label=color('CONTROLLER->HOST', 'cyan'), - emit_message=logger.info + emit_message=logger.info, ): - self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message) - self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message) + self.host_to_controller_analyzer = PacketTracer.Analyzer( + host_to_controller_label, emit_message + ) + self.controller_to_host_analyzer = PacketTracer.Analyzer( + controller_to_host_label, emit_message + ) self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer diff --git a/bumble/hfp.py b/bumble/hfp.py index 6eeb0d9a..f659aa80 100644 --- a/bumble/hfp.py +++ b/bumble/hfp.py @@ -34,9 +34,9 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- class HfpProtocol: def __init__(self, dlc): - self.dlc = dlc - self.buffer = '' - self.lines = collections.deque() + self.dlc = dlc + self.buffer = '' + self.lines = collections.deque() self.lines_available = asyncio.Event() dlc.sink = self.feed @@ -52,7 +52,7 @@ class HfpProtocol: self.buffer += data while (separator := self.buffer.find('\r')) >= 0: line = self.buffer[:separator].strip() - self.buffer = self.buffer[separator + 1:] + self.buffer = self.buffer[separator + 1 :] if len(line) > 0: self.on_line(line) @@ -79,16 +79,16 @@ class HfpProtocol: async def initialize_service(self): # Perform Service Level Connection Initialization self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features - line = await(self.next_line()) - line = await(self.next_line()) + line = await (self.next_line()) + line = await (self.next_line()) self.send_command_line('AT+CIND=?') - line = await(self.next_line()) - line = await(self.next_line()) + line = await (self.next_line()) + line = await (self.next_line()) self.send_command_line('AT+CIND?') - line = await(self.next_line()) - line = await(self.next_line()) + line = await (self.next_line()) + line = await (self.next_line()) self.send_command_line('AT+CMER=3,0,0,1') - line = await(self.next_line()) + line = await (self.next_line()) diff --git a/bumble/host.py b/bumble/host.py index ae4cc666..354d5fbe 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -36,21 +36,25 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27 HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1 HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27 HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 +# fmt: on + # ----------------------------------------------------------------------------- class Connection: def __init__(self, host, handle, role, peer_address, transport): - self.host = host - self.handle = handle - self.role = role - self.peer_address = peer_address - self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) - self.transport = transport + self.host = host + self.handle = handle + self.role = role + self.peer_address = peer_address + self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) + self.transport = transport def on_hci_acl_data_packet(self, packet): self.assembler.feed_packet(packet) @@ -62,29 +66,29 @@ class Connection: # ----------------------------------------------------------------------------- class Host(EventEmitter): - def __init__(self, controller_source = None, controller_sink = None): + def __init__(self, controller_source=None, controller_sink=None): super().__init__() - self.hci_sink = None - self.ready = False # True when we can accept incoming packets - self.connections = {} # Connections, by connection handle - self.pending_command = None - self.pending_response = None - self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH + self.hci_sink = None + self.ready = False # True when we can accept incoming packets + self.connections = {} # Connections, by connection handle + self.pending_command = None + self.pending_response = None + self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS - self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH - self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS - self.acl_packet_queue = collections.deque() - self.acl_packets_in_flight = 0 - self.local_version = None - self.local_supported_commands = bytes(64) - self.local_le_features = 0 - self.suggested_max_tx_octets = 251 # Max allowed - self.suggested_max_tx_time = 2120 # Max allowed - self.command_semaphore = asyncio.Semaphore(1) - self.long_term_key_provider = None - self.link_key_provider = None - self.pairing_io_capability_provider = None # Classic only + self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH + self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS + self.acl_packet_queue = collections.deque() + self.acl_packets_in_flight = 0 + self.local_version = None + self.local_supported_commands = bytes(64) + self.local_le_features = 0 + self.suggested_max_tx_octets = 251 # Max allowed + self.suggested_max_tx_time = 2120 # Max allowed + self.command_semaphore = asyncio.Semaphore(1) + self.long_term_key_provider = None + self.link_key_provider = None + self.pairing_io_capability_provider = None # Classic only # Connect to the source and sink if specified if controller_source: @@ -96,30 +100,51 @@ class Host(EventEmitter): await self.send_command(HCI_Reset_Command(), check_result=True) self.ready = True - response = await self.send_command(HCI_Read_Local_Supported_Commands_Command(), check_result=True) + response = await self.send_command( + HCI_Read_Local_Supported_Commands_Command(), check_result=True + ) self.local_supported_commands = response.return_parameters.supported_commands if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): - response = await self.send_command(HCI_LE_Read_Local_Supported_Features_Command(), check_result=True) - self.local_le_features = struct.unpack(' CONTROLLER", "blue")}: (CID={cid}) {acl_packet}' ) - logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}') self.queue_acl_packet(acl_packet) pb_flag = 1 offset += data_total_length @@ -251,11 +293,16 @@ class Host(EventEmitter): self.check_acl_packet_queue() if len(self.acl_packet_queue): - logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue') + logger.debug( + f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue' + ) def check_acl_packet_queue(self): # Send all we can (TODO: support different LE/Classic limits) - while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets: + while ( + len(self.acl_packet_queue) > 0 + and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets + ): packet = self.acl_packet_queue.pop() self.send_hci_packet(packet) self.acl_packets_in_flight += 1 @@ -267,7 +314,9 @@ class Host(EventEmitter): if value == command: # Check if the flag is set if octet < len(self.local_supported_commands) and flag_position < 8: - return (self.local_supported_commands[octet] & (1 << flag_position)) != 0 + return ( + self.local_supported_commands[octet] & (1 << flag_position) + ) != 0 return False @@ -289,15 +338,17 @@ class Host(EventEmitter): @property def supported_le_features(self): - return [feature for feature in range(64) if self.local_le_features & (1 << feature)] + return [ + feature for feature in range(64) if self.local_le_features & (1 << feature) + ] # Packet Sink protocol (packets coming from the controller via HCI) def on_packet(self, packet): hci_packet = HCI_Packet.from_bytes(packet) if self.ready or ( - hci_packet.hci_packet_type == HCI_EVENT_PACKET and - hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and - hci_packet.command_opcode == HCI_RESET_COMMAND + hci_packet.hci_packet_type == HCI_EVENT_PACKET + and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT + and hci_packet.command_opcode == HCI_RESET_COMMAND ): self.on_hci_packet(hci_packet) else: @@ -336,7 +387,9 @@ class Host(EventEmitter): if self.pending_response: # Check that it is what we were expecting if self.pending_command.op_code != event.command_opcode: - logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}') + logger.warning( + f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}' + ) self.pending_response.set_result(event) else: @@ -364,7 +417,11 @@ class Host(EventEmitter): self.acl_packets_in_flight -= total_packets self.check_acl_packet_queue() else: - logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight')) + logger.warning( + color( + f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight' + ) + ) self.acl_packets_in_flight = 0 # Classic only @@ -381,18 +438,26 @@ class Host(EventEmitter): # Check if this is a cancellation if event.status == HCI_SUCCESS: # Create/update the connection - logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}') + logger.debug( + f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}' + ) connection = self.connections.get(event.connection_handle) if connection is None: - connection = Connection(self, event.connection_handle, event.role, event.peer_address, BT_LE_TRANSPORT) + connection = Connection( + self, + event.connection_handle, + event.role, + event.peer_address, + BT_LE_TRANSPORT, + ) self.connections[event.connection_handle] = connection # Notify the client connection_parameters = ConnectionParameters( event.connection_interval, event.peripheral_latency, - event.supervision_timeout + event.supervision_timeout, ) self.emit( 'connection', @@ -401,13 +466,15 @@ class Host(EventEmitter): event.peer_address, None, event.role, - connection_parameters + connection_parameters, ) else: logger.debug(f'### CONNECTION FAILED: {event.status}') # Notify the listeners - self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status) + self.emit( + 'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status + ) def on_hci_le_enhanced_connection_complete_event(self, event): # Just use the same implementation as for the non-enhanced event for now @@ -416,11 +483,19 @@ class Host(EventEmitter): def on_hci_connection_complete_event(self, event): if event.status == HCI_SUCCESS: # Create/update the connection - logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}') + logger.debug( + f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}' + ) connection = self.connections.get(event.connection_handle) if connection is None: - connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr, BT_BR_EDR_TRANSPORT) + connection = Connection( + self, + event.connection_handle, + BT_CENTRAL_ROLE, + event.bd_addr, + BT_BR_EDR_TRANSPORT, + ) self.connections[event.connection_handle] = connection # Notify the client @@ -431,13 +506,15 @@ class Host(EventEmitter): event.bd_addr, None, BT_CENTRAL_ROLE, - None + None, ) else: logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') # Notify the client - self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status) + self.emit( + 'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status + ) def on_hci_disconnection_complete_event(self, event): # Find the connection @@ -446,7 +523,9 @@ class Host(EventEmitter): return if event.status == HCI_SUCCESS: - logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}') + logger.debug( + f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}' + ) del self.connections[event.connection_handle] # Notify the listeners @@ -467,11 +546,15 @@ class Host(EventEmitter): connection_parameters = ConnectionParameters( event.connection_interval, event.peripheral_latency, - event.supervision_timeout + event.supervision_timeout, + ) + self.emit( + 'connection_parameters_update', connection.handle, connection_parameters ) - self.emit('connection_parameters_update', connection.handle, connection_parameters) else: - self.emit('connection_parameters_update_failure', connection.handle, event.status) + self.emit( + 'connection_parameters_update_failure', connection.handle, event.status + ) def on_hci_le_phy_update_complete_event(self, event): if (connection := self.connections.get(event.connection_handle)) is None: @@ -501,13 +584,13 @@ class Host(EventEmitter): # TODO: delegate the decision self.send_command_sync( HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( - connection_handle = event.connection_handle, - interval_min = event.interval_min, - interval_max = event.interval_max, - latency = event.latency, - timeout = event.timeout, - min_ce_length = 0, - max_ce_length = 0 + connection_handle=event.connection_handle, + interval_min=event.interval_min, + interval_max=event.interval_max, + latency=event.latency, + timeout=event.timeout, + min_ce_length=0, + max_ce_length=0, ) ) @@ -522,18 +605,16 @@ class Host(EventEmitter): long_term_key = None else: long_term_key = await self.long_term_key_provider( - connection.handle, - event.random_number, - event.encryption_diversifier + connection.handle, event.random_number, event.encryption_diversifier ) if long_term_key: response = HCI_LE_Long_Term_Key_Request_Reply_Command( - connection_handle = event.connection_handle, - long_term_key = long_term_key + connection_handle=event.connection_handle, + long_term_key=long_term_key, ) else: response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( - connection_handle = event.connection_handle + connection_handle=event.connection_handle ) await self.send_command(response) @@ -548,10 +629,14 @@ class Host(EventEmitter): def on_hci_role_change_event(self, event): if event.status == HCI_SUCCESS: - logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}') + logger.debug( + f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}' + ) # TODO: lookup the connection and update the role else: - logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}') + logger.debug( + f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}' + ) def on_hci_le_data_length_change_event(self, event): self.emit( @@ -560,7 +645,7 @@ class Host(EventEmitter): event.max_tx_octets, event.max_tx_time, event.max_rx_octets, - event.max_rx_time + event.max_rx_time, ) def on_hci_authentication_complete_event(self, event): @@ -568,21 +653,35 @@ class Host(EventEmitter): if event.status == HCI_SUCCESS: self.emit('connection_authentication', event.connection_handle) else: - self.emit('connection_authentication_failure', event.connection_handle, event.status) + self.emit( + 'connection_authentication_failure', + event.connection_handle, + event.status, + ) def on_hci_encryption_change_event(self, event): # Notify the client if event.status == HCI_SUCCESS: - self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled) + self.emit( + 'connection_encryption_change', + event.connection_handle, + event.encryption_enabled, + ) else: - self.emit('connection_encryption_failure', event.connection_handle, event.status) + self.emit( + 'connection_encryption_failure', event.connection_handle, event.status + ) def on_hci_encryption_key_refresh_complete_event(self, event): # Notify the client if event.status == HCI_SUCCESS: self.emit('connection_encryption_key_refresh', event.connection_handle) else: - self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status) + self.emit( + 'connection_encryption_key_refresh_failure', + event.connection_handle, + event.status, + ) def on_hci_link_supervision_timeout_changed_event(self, event): pass @@ -594,11 +693,15 @@ class Host(EventEmitter): pass def on_hci_link_key_notification_event(self, event): - logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}') + logger.debug( + f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}' + ) self.emit('link_key', event.bd_addr, event.link_key, event.key_type) def on_hci_simple_pairing_complete_event(self, event): - logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') + logger.debug( + f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}' + ) # Notify the client if event.status == HCI_SUCCESS: self.emit('ssp_complete', event.bd_addr) @@ -607,9 +710,7 @@ class Host(EventEmitter): # For now, just refuse all requests # TODO: delegate the decision self.send_command_sync( - HCI_PIN_Code_Request_Negative_Reply_Command( - bd_addr = event.bd_addr - ) + HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr) ) def on_hci_link_key_request_event(self, event): @@ -621,12 +722,11 @@ class Host(EventEmitter): link_key = await self.link_key_provider(event.bd_addr) if link_key: response = HCI_Link_Key_Request_Reply_Command( - bd_addr = event.bd_addr, - link_key = link_key + bd_addr=event.bd_addr, link_key=link_key ) else: response = HCI_Link_Key_Request_Negative_Reply_Command( - bd_addr = event.bd_addr + bd_addr=event.bd_addr ) await self.send_command(response) @@ -640,13 +740,19 @@ class Host(EventEmitter): pass def on_hci_user_confirmation_request_event(self, event): - self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value) + self.emit( + 'authentication_user_confirmation_request', + event.bd_addr, + event.numeric_value, + ) def on_hci_user_passkey_request_event(self, event): self.emit('authentication_user_passkey_request', event.bd_addr) def on_hci_user_passkey_notification_event(self, event): - self.emit('authentication_user_passkey_notification', event.bd_addr, event.passkey) + self.emit( + 'authentication_user_passkey_notification', event.bd_addr, event.passkey + ) def on_hci_inquiry_complete_event(self, event): self.emit('inquiry_complete') @@ -658,7 +764,7 @@ class Host(EventEmitter): response.bd_addr, response.class_of_device, b'', - response.rssi + response.rssi, ) def on_hci_extended_inquiry_result_event(self, event): @@ -667,7 +773,7 @@ class Host(EventEmitter): event.bd_addr, event.class_of_device, event.extended_inquiry_response, - event.rssi + event.rssi, ) def on_hci_remote_name_request_complete_event(self, event): @@ -677,4 +783,8 @@ class Host(EventEmitter): self.emit('remote_name', event.bd_addr, event.remote_name) def on_hci_remote_host_supported_features_notification_event(self, event): - self.emit('remote_host_supported_features', event.bd_addr, event.host_supported_features) + self.emit( + 'remote_host_supported_features', + event.bd_addr, + event.host_supported_features, + ) diff --git a/bumble/keys.py b/bumble/keys.py index b8c05b48..cbb58ca6 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -39,10 +39,10 @@ logger = logging.getLogger(__name__) class PairingKeys: class Key: def __init__(self, value, authenticated=False, ediv=None, rand=None): - self.value = value + self.value = value self.authenticated = authenticated - self.ediv = ediv - self.rand = rand + self.ediv = ediv + self.rand = rand @classmethod def from_dict(cls, key_dict): @@ -65,13 +65,13 @@ class PairingKeys: return key_dict def __init__(self): - self.address_type = None - self.ltk = None - self.ltk_central = None + self.address_type = None + self.ltk = None + self.ltk_central = None self.ltk_peripheral = None - self.irk = None - self.csrk = None - self.link_key = None # Classic + self.irk = None + self.csrk = None + self.link_key = None # Classic @staticmethod def key_from_dict(keys_dict, key_name): @@ -83,13 +83,13 @@ class PairingKeys: def from_dict(keys_dict): keys = PairingKeys() - keys.address_type = keys_dict.get('address_type') - keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') - keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') + keys.address_type = keys_dict.get('address_type') + keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') + keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral') - keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') - keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') - keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') + keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') + keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') + keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') return keys @@ -166,7 +166,7 @@ class KeyStore: separator = '' for (name, keys) in entries: print(separator + prefix + color(name, 'yellow')) - keys.print(prefix = prefix + ' ') + keys.print(prefix=prefix + ' ') separator = '\n' @staticmethod @@ -183,9 +183,9 @@ class KeyStore: # ----------------------------------------------------------------------------- class JsonKeyStore(KeyStore): - APP_NAME = 'Bumble' - APP_AUTHOR = 'Google' - KEYS_DIR = 'Pairing' + APP_NAME = 'Bumble' + APP_AUTHOR = 'Google' + KEYS_DIR = 'Pairing' DEFAULT_NAMESPACE = '__DEFAULT__' def __init__(self, namespace, filename=None): @@ -194,9 +194,9 @@ class JsonKeyStore(KeyStore): if filename is None: # Use a default for the current user import appdirs + self.directory_name = os.path.join( - appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), - self.KEYS_DIR + appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR ) json_filename = f'{self.namespace}.json'.lower().replace(':', '-') self.filename = os.path.join(self.directory_name, json_filename) @@ -262,7 +262,9 @@ class JsonKeyStore(KeyStore): if namespace is None: return [] - return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()] + return [ + (name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items() + ] async def delete_all(self): db = await self.load() diff --git a/bumble/l2cap.py b/bumble/l2cap.py index c61a45fd..9f438f1d 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -24,8 +24,12 @@ from colors import color from pyee import EventEmitter from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError -from .hci import (HCI_LE_Connection_Update_Command, HCI_Object, key_with_value, - name_or_number) +from .hci import ( + HCI_LE_Connection_Update_Command, + HCI_Object, + key_with_value, + name_or_number, +) # ----------------------------------------------------------------------------- # Logging @@ -36,6 +40,8 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + L2CAP_SIGNALING_CID = 0x01 L2CAP_LE_SIGNALING_CID = 0x05 @@ -130,6 +136,8 @@ L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01 L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01 +# fmt: on + # ----------------------------------------------------------------------------- # Classes @@ -155,7 +163,7 @@ class L2CAP_PDU: return header + self.payload def __init__(self, cid, payload): - self.cid = cid + self.cid = cid self.payload = payload def __bytes__(self): @@ -170,6 +178,7 @@ class L2CAP_Control_Frame: ''' See Bluetooth spec @ Vol 3, Part A - 4 SIGNALING PACKET FORMATS ''' + classes = {} code = 0 @@ -188,7 +197,12 @@ class L2CAP_Control_Frame: self.identifier = pdu[1] length = struct.unpack_from('= 2: - type = data[0] + type = data[0] length = data[1] - value = data[2:2 + length] - data = data[2 + length:] + value = data[2 : 2 + length] + data = data[2 + length :] options.append((type, value)) return options @staticmethod def encode_configuration_options(options): - return b''.join([bytes([option[0], len(option[1])]) + option[1] for option in options]) + return b''.join( + [bytes([option[0], len(option[1])]) + option[1] for option in options] + ) @staticmethod def subclass(fields): @@ -219,7 +235,9 @@ class L2CAP_Control_Frame: cls.name = cls.__name__.upper() cls.code = key_with_value(L2CAP_CONTROL_FRAME_NAMES, cls.name) if cls.code is None: - raise KeyError(f'Control Frame name {cls.name} not found in L2CAP_CONTROL_FRAME_NAMES') + raise KeyError( + f'Control Frame name {cls.name} not found in L2CAP_CONTROL_FRAME_NAMES' + ) cls.fields = fields # Register a factory for this class @@ -235,7 +253,11 @@ class L2CAP_Control_Frame: HCI_Object.init_from_fields(self, self.fields, kwargs) if pdu is None: data = HCI_Object.dict_to_bytes(kwargs, self.fields) - pdu = bytes([self.code, self.identifier]) + struct.pack(' {color(Channel.STATE_NAMES[new_state], "cyan")}') + logger.debug( + f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}' + ) self.state = new_state def send_pdu(self, pdu): @@ -681,7 +725,9 @@ class Channel(EventEmitter): elif self.sink: self.sink(pdu) else: - logger.warning(color('received pdu without a pending request or sink', 'red')) + logger.warning( + color('received pdu without a pending request or sink', 'red') + ) async def connect(self): if self.state != Channel.CLOSED: @@ -694,9 +740,9 @@ class Channel(EventEmitter): self.change_state(Channel.WAIT_CONNECT_RSP) self.send_control_frame( L2CAP_Connection_Request( - identifier = self.manager.next_identifier(self.connection), - psm = self.psm, - source_cid = self.source_cid + identifier=self.manager.next_identifier(self.connection), + psm=self.psm, + source_cid=self.source_cid, ) ) @@ -716,9 +762,9 @@ class Channel(EventEmitter): self.change_state(Channel.WAIT_DISCONNECT) self.send_control_frame( L2CAP_Disconnection_Request( - identifier = self.manager.next_identifier(self.connection), - destination_cid = self.destination_cid, - source_cid = self.source_cid + identifier=self.manager.next_identifier(self.connection), + destination_cid=self.destination_cid, + source_cid=self.source_cid, ) ) @@ -727,16 +773,20 @@ class Channel(EventEmitter): return await self.disconnection_result def send_configure_request(self): - options = L2CAP_Control_Frame.encode_configuration_options([( - L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE, - struct.pack(' {color(self.state_name(new_state), "cyan")}') + logger.debug( + f'{self} state change -> {color(self.state_name(new_state), "cyan")}' + ) self.state = new_state if new_state == self.CONNECTED: @@ -975,12 +1041,12 @@ class LeConnectionOrientedChannel(EventEmitter): self.change_state(self.CONNECTING) request = L2CAP_LE_Credit_Based_Connection_Request( - identifier = identifier, - le_psm = self.le_psm, - source_cid = self.source_cid, - mtu = self.mtu, - mps = self.mps, - initial_credits = self.peer_credits + identifier=identifier, + le_psm=self.le_psm, + source_cid=self.source_cid, + mtu=self.mtu, + mps=self.mps, + initial_credits=self.peer_credits, ) self.manager.le_coc_requests[identifier] = request self.send_control_frame(request) @@ -1000,9 +1066,9 @@ class LeConnectionOrientedChannel(EventEmitter): self.flush_output() self.send_control_frame( L2CAP_Disconnection_Request( - identifier = self.manager.next_identifier(self.connection), - destination_cid = self.destination_cid, - source_cid = self.source_cid + identifier=self.manager.next_identifier(self.connection), + destination_cid=self.destination_cid, + source_cid=self.source_cid, ) ) @@ -1027,9 +1093,9 @@ class LeConnectionOrientedChannel(EventEmitter): # The credits fell below the threshold, replenish them to the max self.send_control_frame( L2CAP_LE_Flow_Control_Credit( - identifier = self.manager.next_identifier(self.connection), - cid = self.source_cid, - credits = self.peer_max_credits - self.peer_credits + identifier=self.manager.next_identifier(self.connection), + cid=self.source_cid, + credits=self.peer_max_credits - self.peer_credits, ) ) self.peer_credits = self.peer_max_credits @@ -1052,11 +1118,15 @@ class LeConnectionOrientedChannel(EventEmitter): return if len(self.in_sdu) < 2 + self.in_sdu_length: # Not complete yet - logger.debug(f'SDU: {len(self.in_sdu) - 2} of {self.in_sdu_length} bytes received') + logger.debug( + f'SDU: {len(self.in_sdu) - 2} of {self.in_sdu_length} bytes received' + ) return if len(self.in_sdu) != 2 + self.in_sdu_length: # Overflow - logger.warning(f'SDU overflow: sdu_length={self.in_sdu_length}, received {len(self.in_sdu) - 2}') + logger.warning( + f'SDU overflow: sdu_length={self.in_sdu_length}, received {len(self.in_sdu) - 2}' + ) # TODO: we should disconnect self.in_sdu = None self.in_sdu_length = 0 @@ -1073,15 +1143,20 @@ class LeConnectionOrientedChannel(EventEmitter): def on_connection_response(self, response): # Look for a matching pending response result if self.connection_result is None: - logger.warning(f'received unexpected connection response (id={response.identifier})') + logger.warning( + f'received unexpected connection response (id={response.identifier})' + ) return - if response.result == L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL: + if ( + response.result + == L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL + ): self.destination_cid = response.destination_cid - self.peer_mtu = response.mtu - self.peer_mps = response.mps - self.credits = response.initial_credits - self.connected = True + self.peer_mtu = response.mtu + self.peer_mps = response.mps + self.credits = response.initial_credits + self.connected = True self.connection_result.set_result(self) self.change_state(self.CONNECTED) else: @@ -1089,7 +1164,10 @@ class LeConnectionOrientedChannel(EventEmitter): ProtocolError( response.result, 'l2cap', - L2CAP_LE_Credit_Based_Connection_Response.result_name(response.result)) + L2CAP_LE_Credit_Based_Connection_Response.result_name( + response.result + ), + ) ) self.change_state(self.CONNECTION_ERROR) @@ -1106,9 +1184,9 @@ class LeConnectionOrientedChannel(EventEmitter): def on_disconnection_request(self, request): self.send_control_frame( L2CAP_Disconnection_Response( - identifier = request.identifier, - destination_cid = request.destination_cid, - source_cid = request.source_cid + identifier=request.identifier, + destination_cid=request.destination_cid, + source_cid=request.source_cid, ) ) self.change_state(self.DISCONNECTED) @@ -1119,7 +1197,10 @@ class LeConnectionOrientedChannel(EventEmitter): logger.warning(color('invalid state', 'red')) return - if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid: + if ( + response.destination_cid != self.destination_cid + or response.source_cid != self.source_cid + ): logger.warning('unexpected source or destination CID') return @@ -1136,7 +1217,7 @@ class LeConnectionOrientedChannel(EventEmitter): while self.credits > 0: if self.out_sdu is not None: # Finish the current SDU - packet = self.out_sdu[:self.peer_mps] + packet = self.out_sdu[: self.peer_mps] self.send_pdu(packet) self.credits -= 1 logger.debug(f'sent {len(packet)} bytes, {self.credits} credits left') @@ -1145,21 +1226,25 @@ class LeConnectionOrientedChannel(EventEmitter): self.out_sdu = None else: # Keep what's still left to send - self.out_sdu = self.out_sdu[len(packet):] + self.out_sdu = self.out_sdu[len(packet) :] continue elif self.out_queue: # Create the next SDU (2 bytes header plus up to MTU bytes payload) - logger.debug(f'assembling SDU from {len(self.out_queue)} packets in output queue') + logger.debug( + f'assembling SDU from {len(self.out_queue)} packets in output queue' + ) payload = b'' while self.out_queue and len(payload) < self.peer_mtu: # We can add more data to the payload - chunk = self.out_queue[0][:self.peer_mtu - len(payload)] + chunk = self.out_queue[0][: self.peer_mtu - len(payload)] payload += chunk - self.out_queue[0] = self.out_queue[0][len(chunk):] + self.out_queue[0] = self.out_queue[0][len(chunk) :] if len(self.out_queue[0]) == 0: # We consumed the entire buffer, remove it self.out_queue.popleft() - logger.debug(f'packet completed, {len(self.out_queue)} left in queue') + logger.debug( + f'packet completed, {len(self.out_queue)} left in queue' + ) # Construct the SDU with its header assert len(payload) != 0 @@ -1178,7 +1263,9 @@ class LeConnectionOrientedChannel(EventEmitter): # Queue the data self.out_queue.append(data) self.drained.clear() - logger.debug(f'{len(data)} bytes packet queued, {len(self.out_queue)} packets in queue') + logger.debug( + f'{len(data)} bytes packet queued, {len(self.out_queue)} packets in queue' + ) # Send what we can self.process_output() @@ -1200,18 +1287,23 @@ class LeConnectionOrientedChannel(EventEmitter): # ----------------------------------------------------------------------------- class ChannelManager: - def __init__(self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU): - self._host = None - self.identifiers = {} # Incrementing identifier values by connection - self.channels = {} # All channels, mapped by connection and source cid - self.fixed_channels = { # Fixed channel handlers, mapped by cid - L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None + def __init__( + self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU + ): + self._host = None + self.identifiers = {} # Incrementing identifier values by connection + self.channels = {} # All channels, mapped by connection and source cid + self.fixed_channels = { # Fixed channel handlers, mapped by cid + L2CAP_SIGNALING_CID: None, + L2CAP_LE_SIGNALING_CID: None, } - self.servers = {} # Servers accepting connections, by PSM - self.le_coc_channels = {} # LE CoC channels, mapped by connection and destination cid - self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM - self.le_coc_requests = {} # LE CoC connection requests, by identifier - self.extended_features = extended_features + self.servers = {} # Servers accepting connections, by PSM + self.le_coc_channels = ( + {} + ) # LE CoC channels, mapped by connection and destination cid + self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM + self.le_coc_requests = {} # LE CoC connection requests, by identifier + self.extended_features = extended_features self.connectionless_mtu = connectionless_mtu @property @@ -1239,7 +1331,9 @@ class ChannelManager: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) - for cid in range(L2CAP_ACL_U_DYNAMIC_CID_RANGE_START, L2CAP_ACL_U_DYNAMIC_CID_RANGE_END + 1): + for cid in range( + L2CAP_ACL_U_DYNAMIC_CID_RANGE_START, L2CAP_ACL_U_DYNAMIC_CID_RANGE_END + 1 + ): if cid not in channels: return cid @@ -1248,17 +1342,25 @@ class ChannelManager: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) - for cid in range(L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1): + for cid in range( + L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1 + ): if cid not in channels: return cid @staticmethod def check_le_coc_parameters(max_credits, mtu, mps): - if max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS: + if ( + max_credits < 1 + or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS + ): raise ValueError('max credits out of range') if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU: raise ValueError('MTU too small') - if mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS: + if ( + mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS + or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS + ): raise ValueError('MPS out of range') def next_identifier(self, connection): @@ -1276,7 +1378,9 @@ class ChannelManager: def register_server(self, psm, server): if psm == 0: # Find a free PSM - for candidate in range(L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2): + for candidate in range( + L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 + ): if (candidate >> 8) % 2 == 1: continue if candidate in self.servers: @@ -1309,13 +1413,15 @@ class ChannelManager: server, max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, ): self.check_le_coc_parameters(max_credits, mtu, mps) if psm == 0: # Find a free PSM - for candidate in range(L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1): + for candidate in range( + L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 + ): if candidate in self.le_coc_servers: continue psm = candidate @@ -1331,7 +1437,7 @@ class ChannelManager: server, max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, ) return psm @@ -1347,7 +1453,9 @@ class ChannelManager: def send_pdu(self, connection, cid, pdu): pdu_str = pdu.hex() if type(pdu) is bytes else str(pdu) - logger.debug(f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}') + logger.debug( + f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}' + ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) def on_pdu(self, connection, cid, pdu): @@ -1360,17 +1468,25 @@ class ChannelManager: self.fixed_channels[cid](connection.handle, pdu) else: if (channel := self.find_channel(connection.handle, cid)) is None: - logger.warning(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning( + color( + f'channel not found for 0x{connection.handle:04X}:{cid}', 'red' + ) + ) return channel.on_pdu(pdu) def send_control_frame(self, connection, cid, control_frame): - logger.debug(f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}') + logger.debug( + f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}' + ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) def on_control_frame(self, connection, cid, control_frame): - logger.debug(f'{color("<<< Received L2CAP Signaling Control Frame", "green")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}') + logger.debug( + f'{color("<<< Received L2CAP Signaling Control Frame", "green")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}' + ) # Find the handler method handler_name = f'on_{control_frame.name.lower()}' @@ -1384,10 +1500,10 @@ class ChannelManager: connection, cid, L2CAP_Command_Reject( - identifier = control_frame.identifier, - reason = L2CAP_COMMAND_NOT_UNDERSTOOD_REASON, - data = b'' - ) + identifier=control_frame.identifier, + reason=L2CAP_COMMAND_NOT_UNDERSTOOD_REASON, + data=b'', + ), ) raise error else: @@ -1396,10 +1512,10 @@ class ChannelManager: connection, cid, L2CAP_Command_Reject( - identifier = control_frame.identifier, - reason = L2CAP_COMMAND_NOT_UNDERSTOOD_REASON, - data = b'' - ) + identifier=control_frame.identifier, + reason=L2CAP_COMMAND_NOT_UNDERSTOOD_REASON, + data=b'', + ), ) def on_l2cap_command_reject(self, connection, cid, packet): @@ -1417,68 +1533,109 @@ class ChannelManager: connection, cid, L2CAP_Connection_Response( - identifier = request.identifier, - destination_cid = request.source_cid, - source_cid = 0, - result = L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, - status = 0x0000 - ) + identifier=request.identifier, + destination_cid=request.source_cid, + source_cid=0, + result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + status=0x0000, + ), ) return # Create a new channel - logger.debug(f'creating server channel with cid={source_cid} for psm {request.psm}') - channel = Channel(self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU) + logger.debug( + f'creating server channel with cid={source_cid} for psm {request.psm}' + ) + channel = Channel( + self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU + ) connection_channels[source_cid] = channel # Notify server(channel) channel.on_connection_request(request) else: - logger.warning(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}') + logger.warning( + f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}' + ) self.send_control_frame( connection, cid, L2CAP_Connection_Response( - identifier = request.identifier, - destination_cid = request.source_cid, - source_cid = 0, - result = L2CAP_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, - status = 0x0000 - ) + identifier=request.identifier, + destination_cid=request.source_cid, + source_cid=0, + result=L2CAP_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, + status=0x0000, + ), ) def on_l2cap_connection_response(self, connection, cid, response): - if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + if ( + channel := self.find_channel(connection.handle, response.source_cid) + ) is None: + logger.warning( + color( + f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', + 'red', + ) + ) return channel.on_connection_response(response) def on_l2cap_configure_request(self, connection, cid, request): - if (channel := self.find_channel(connection.handle, request.destination_cid)) is None: - logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + if ( + channel := self.find_channel(connection.handle, request.destination_cid) + ) is None: + logger.warning( + color( + f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', + 'red', + ) + ) return channel.on_configure_request(request) def on_l2cap_configure_response(self, connection, cid, response): - if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + if ( + channel := self.find_channel(connection.handle, response.source_cid) + ) is None: + logger.warning( + color( + f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', + 'red', + ) + ) return channel.on_configure_response(response) def on_l2cap_disconnection_request(self, connection, cid, request): - if (channel := self.find_channel(connection.handle, request.destination_cid)) is None: - logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + if ( + channel := self.find_channel(connection.handle, request.destination_cid) + ) is None: + logger.warning( + color( + f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', + 'red', + ) + ) return channel.on_disconnection_request(request) def on_l2cap_disconnection_response(self, connection, cid, response): - if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + if ( + channel := self.find_channel(connection.handle, response.source_cid) + ) is None: + logger.warning( + color( + f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', + 'red', + ) + ) return channel.on_disconnection_response(response) @@ -1488,10 +1645,7 @@ class ChannelManager: self.send_control_frame( connection, cid, - L2CAP_Echo_Response( - identifier = request.identifier, - data = request.data - ) + L2CAP_Echo_Response(identifier=request.identifier, data=request.data), ) def on_l2cap_echo_response(self, connection, cid, response): @@ -1515,11 +1669,11 @@ class ChannelManager: connection, cid, L2CAP_Information_Response( - identifier = request.identifier, - info_type = request.info_type, - result = result, - data = data - ) + identifier=request.identifier, + info_type=request.info_type, + result=result, + data=data, + ), ) def on_l2cap_connection_parameter_update_request(self, connection, cid, request): @@ -1528,27 +1682,29 @@ class ChannelManager: connection, cid, L2CAP_Connection_Parameter_Update_Response( - identifier = request.identifier, - result = L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT + identifier=request.identifier, + result=L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT, + ), + ) + self.host.send_command_sync( + HCI_LE_Connection_Update_Command( + connection_handle=connection.handle, + connection_interval_min=request.interval_min, + connection_interval_max=request.interval_max, + max_latency=request.latency, + supervision_timeout=request.timeout, + min_ce_length=0, + max_ce_length=0, ) ) - self.host.send_command_sync(HCI_LE_Connection_Update_Command( - connection_handle = connection.handle, - connection_interval_min = request.interval_min, - connection_interval_max = request.interval_max, - max_latency = request.latency, - supervision_timeout = request.timeout, - min_ce_length = 0, - max_ce_length = 0 - )) else: self.send_control_frame( connection, cid, L2CAP_Connection_Parameter_Update_Response( - identifier = request.identifier, - result = L2CAP_CONNECTION_PARAMETERS_REJECTED_RESULT - ) + identifier=request.identifier, + result=L2CAP_CONNECTION_PARAMETERS_REJECTED_RESULT, + ), ) def on_l2cap_connection_parameter_update_response(self, connection, cid, response): @@ -1560,20 +1716,22 @@ class ChannelManager: (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] # Check that the CID isn't already used - le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + le_connection_channels = self.le_coc_channels.setdefault( + connection.handle, {} + ) if request.source_cid in le_connection_channels: logger.warning(f'source CID {request.source_cid} already in use') self.send_control_frame( connection, cid, L2CAP_LE_Credit_Based_Connection_Response( - identifier = request.identifier, - destination_cid = 0, - mtu = mtu, - mps = mps, - initial_credits = 0, - result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED - ) + identifier=request.identifier, + destination_cid=0, + mtu=mtu, + mps=mps, + initial_credits=0, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + ), ) return @@ -1585,18 +1743,20 @@ class ChannelManager: connection, cid, L2CAP_LE_Credit_Based_Connection_Response( - identifier = request.identifier, - destination_cid = 0, - mtu = mtu, - mps = mps, - initial_credits = 0, - result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, - ) + identifier=request.identifier, + destination_cid=0, + mtu=mtu, + mps=mps, + initial_credits=0, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + ), ) return # Create a new channel - logger.debug(f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}') + logger.debug( + f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}' + ) channel = LeConnectionOrientedChannel( self, connection, @@ -1609,7 +1769,7 @@ class ChannelManager: request.mtu, request.mps, max_credits, - True + True, ) connection_channels[source_cid] = channel le_connection_channels[request.source_cid] = channel @@ -1619,30 +1779,32 @@ class ChannelManager: connection, cid, L2CAP_LE_Credit_Based_Connection_Response( - identifier = request.identifier, - destination_cid = source_cid, - mtu = mtu, - mps = mps, - initial_credits = max_credits, - result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL - ) + identifier=request.identifier, + destination_cid=source_cid, + mtu=mtu, + mps=mps, + initial_credits=max_credits, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, + ), ) # Notify server(channel) else: - logger.info(f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}') + logger.info( + f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}' + ) self.send_control_frame( connection, cid, L2CAP_LE_Credit_Based_Connection_Response( - identifier = request.identifier, - destination_cid = 0, - mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - initial_credits = 0, - result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, - ) + identifier=request.identifier, + destination_cid=0, + mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + initial_credits=0, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, + ), ) def on_l2cap_le_credit_based_connection_response(self, connection, cid, response): @@ -1656,7 +1818,12 @@ class ChannelManager: # Find the channel for this request channel = self.find_channel(connection.handle, request.source_cid) if channel is None: - logger.warning(color(f'received connection response for an unknown channel (cid={request.source_cid})', 'red')) + logger.warning( + color( + f'received connection response for an unknown channel (cid={request.source_cid})', + 'red', + ) + ) return # Process the response @@ -1688,18 +1855,18 @@ class ChannelManager: # Create the channel logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}') channel = LeConnectionOrientedChannel( - manager = self, - connection = connection, - le_psm = psm, - source_cid = source_cid, - destination_cid = 0, - mtu = mtu, - mps = mps, - credits = 0, - peer_mtu = 0, - peer_mps = 0, - peer_credits = max_credits, - connected = False + manager=self, + connection=connection, + le_psm=psm, + source_cid=source_cid, + destination_cid=0, + mtu=mtu, + mps=mps, + credits=0, + peer_mtu=0, + peer_mps=0, + peer_credits=max_credits, + connected=False, ) connection_channels[source_cid] = channel @@ -1728,7 +1895,9 @@ class ChannelManager: # Create the channel logger.debug(f'creating client channel with cid={source_cid} for psm {psm}') - channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU) + channel = Channel( + self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU + ) connection_channels[source_cid] = channel # Connect diff --git a/bumble/link.py b/bumble/link.py index 4463e271..54649c6c 100644 --- a/bumble/link.py +++ b/bumble/link.py @@ -25,7 +25,7 @@ from bumble.hci import ( Address, HCI_SUCCESS, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, - HCI_CONNECTION_TIMEOUT_ERROR + HCI_CONNECTION_TIMEOUT_ERROR, ) # ----------------------------------------------------------------------------- @@ -55,7 +55,7 @@ class LocalLink: ''' def __init__(self): - self.controllers = set() + self.controllers = set() self.pending_connection = None def add_controller(self, controller): @@ -103,23 +103,30 @@ class LocalLink: return # Connect to the first controller with a matching address - if peripheral_controller := self.find_controller(le_create_connection_command.peer_address): - central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS) + if peripheral_controller := self.find_controller( + le_create_connection_command.peer_address + ): + central_controller.on_link_peripheral_connection_complete( + le_create_connection_command, HCI_SUCCESS + ) peripheral_controller.on_link_central_connected(central_address) return # No peripheral found central_controller.on_link_peripheral_connection_complete( - le_create_connection_command, - HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR + le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR ) def connect(self, central_address, le_create_connection_command): - logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}') + logger.debug( + f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}' + ) self.pending_connection = (central_address, le_create_connection_command) asyncio.get_running_loop().call_soon(self.on_connection_complete) - def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command): + def on_disconnection_complete( + self, central_address, peripheral_address, disconnect_command + ): # Find the controller that initiated the disconnection if not (central_controller := self.find_controller(central_address)): logger.warning('!!! Initiating controller not found') @@ -127,16 +134,24 @@ class LocalLink: # Disconnect from the first controller with a matching address if peripheral_controller := self.find_controller(peripheral_address): - peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason) + peripheral_controller.on_link_central_disconnected( + central_address, disconnect_command.reason + ) - central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) + central_controller.on_link_peripheral_disconnection_complete( + disconnect_command, HCI_SUCCESS + ) def disconnect(self, central_address, peripheral_address, disconnect_command): - logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') + logger.debug( + f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}' + ) args = [central_address, peripheral_address, disconnect_command] asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args) - def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): + def on_connection_encrypted( + self, central_address, peripheral_address, rand, ediv, ltk + ): logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}') if central_controller := self.find_controller(central_address): @@ -152,15 +167,18 @@ class RemoteLink: A Link implementation that communicates with other virtual controllers via a WebSocket relay ''' + def __init__(self, uri): - self.controller = None - self.uri = uri - self.execution_queue = asyncio.Queue() - self.websocket = asyncio.get_running_loop().create_future() - self.rpc_result = None - self.pending_connection = None - self.central_connections = set() # List of addresses that we have connected to - self.peripheral_connections = set() # List of addresses that have connected to us + self.controller = None + self.uri = uri + self.execution_queue = asyncio.Queue() + self.websocket = asyncio.get_running_loop().create_future() + self.rpc_result = None + self.pending_connection = None + self.central_connections = set() # List of addresses that we have connected to + self.peripheral_connections = ( + set() + ) # List of addresses that have connected to us # Connect and run asynchronously asyncio.create_task(self.run_connection()) @@ -192,7 +210,9 @@ class RemoteLink: try: await item except Exception as error: - logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}') + logger.warning( + f'{color("!!! Exception in async handler:", "red")} {error}' + ) async def run_connection(self): # Connect to the relay @@ -227,7 +247,9 @@ class RemoteLink: self.central_connections.remove(address) if address in self.peripheral_connections: - self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR) + self.controller.on_link_central_disconnected( + address, HCI_CONNECTION_TIMEOUT_ERROR + ) self.peripheral_connections.remove(address) async def on_unreachable_received(self, target): @@ -244,7 +266,9 @@ class RemoteLink: async def on_advertisement_message_received(self, sender, advertisement): try: - self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement)) + self.controller.on_link_advertising_data( + Address(sender), bytes.fromhex(advertisement) + ) except Exception: logger.exception('exception') @@ -275,7 +299,9 @@ class RemoteLink: # Notify the controller logger.debug(f'connected to peripheral {self.pending_connection.peer_address}') - self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS) + self.controller.on_link_peripheral_connection_complete( + self.pending_connection, HCI_SUCCESS + ) async def on_disconnect_message_received(self, sender, message): # Notify the controller @@ -296,7 +322,7 @@ class RemoteLink: websocket = await self.websocket # Create a future value to hold the eventual result - assert(self.rpc_result is None) + assert self.rpc_result is None self.rpc_result = asyncio.get_running_loop().create_future() # Send the command @@ -345,16 +371,43 @@ class RemoteLink: logger.warn('connection already pending') return self.pending_connection = le_create_connection_command - self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address))) + self.execute( + partial( + self.send_connection_request_to_relay, + str(le_create_connection_command.peer_address), + ) + ) def on_disconnection_complete(self, disconnect_command): - self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) + self.controller.on_link_peripheral_disconnection_complete( + disconnect_command, HCI_SUCCESS + ) def disconnect(self, central_address, peripheral_address, disconnect_command): - logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') - self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}')) - asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command) + logger.debug( + f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}' + ) + self.execute( + partial( + self.send_targetted_message, + peripheral_address, + f'disconnect:reason={disconnect_command.reason}', + ) + ) + asyncio.get_running_loop().call_soon( + self.on_disconnection_complete, disconnect_command + ) - def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): - asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk) - self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}')) + def on_connection_encrypted( + self, central_address, peripheral_address, rand, ediv, ltk + ): + asyncio.get_running_loop().call_soon( + self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk + ) + self.execute( + partial( + self.send_targetted_message, + peripheral_address, + f'encrypted:ltk={ltk.hex()}', + ) + ) diff --git a/bumble/profiles/asha_service.py b/bumble/profiles/asha_service.py index becfde49..e5565bfc 100644 --- a/bumble/profiles/asha_service.py +++ b/bumble/profiles/asha_service.py @@ -29,7 +29,7 @@ from ..gatt import ( TemplateService, Characteristic, CharacteristicValue, - PackedCharacteristicAdapter + PackedCharacteristicAdapter, ) # ----------------------------------------------------------------------------- @@ -66,7 +66,8 @@ class AshaService(TemplateService): # Start audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]] logger.info( - f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}') + f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}' + ) elif opcode == AshaService.OPCODE_STOP: logger.info('### STOP') elif opcode == AshaService.OPCODE_STATUS: @@ -79,34 +80,36 @@ class AshaService(TemplateService): GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - bytes([ - AshaService.PROTOCOL_VERSION, # Version - self.capability, - ]) + - bytes(self.hisyncid) + - bytes(AshaService.FEATURE_MAP) + - bytes(AshaService.RENDER_DELAY) + - bytes(AshaService.RESERVED_FOR_FUTURE_USE) + - bytes(AshaService.SUPPORTED_CODEC_ID) + bytes( + [ + AshaService.PROTOCOL_VERSION, # Version + self.capability, + ] + ) + + bytes(self.hisyncid) + + bytes(AshaService.FEATURE_MAP) + + bytes(AshaService.RENDER_DELAY) + + bytes(AshaService.RESERVED_FOR_FUTURE_USE) + + bytes(AshaService.SUPPORTED_CODEC_ID), ) self.audio_control_point_characteristic = Characteristic( GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC, Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.WRITEABLE, - CharacteristicValue(write=on_audio_control_point_write) + CharacteristicValue(write=on_audio_control_point_write), ) self.audio_status_characteristic = Characteristic( GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC, Characteristic.READ | Characteristic.NOTIFY, Characteristic.READABLE, - bytes([0]) + bytes([0]), ) self.volume_characteristic = Characteristic( GATT_ASHA_VOLUME_CHARACTERISTIC, Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.WRITEABLE, - CharacteristicValue(write=on_volume_write) + CharacteristicValue(write=on_volume_write), ) # TODO add real psm value @@ -116,26 +119,39 @@ class AshaService(TemplateService): GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - struct.pack(' 0xFFFF: raise ValueError('heart_rate out of range') - if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF): + if energy_expended is not None and ( + energy_expended < 0 or energy_expended > 0xFFFF + ): raise ValueError('energy_expended out of range') if rr_intervals: @@ -69,10 +71,10 @@ class HeartRateService(TemplateService): if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: raise ValueError('rr_intervals out of range') - self.heart_rate = heart_rate + self.heart_rate = heart_rate self.sensor_contact_detected = sensor_contact_detected - self.energy_expended = energy_expended - self.rr_intervals = rr_intervals + self.energy_expended = energy_expended + self.rr_intervals = rr_intervals @classmethod def from_bytes(cls, data): @@ -87,7 +89,7 @@ class HeartRateService(TemplateService): offset += 1 if flags & (1 << 2): - sensor_contact_detected = (flags & (1 << 1) != 0) + sensor_contact_detected = flags & (1 << 1) != 0 else: sensor_contact_detected = None @@ -119,38 +121,42 @@ class HeartRateService(TemplateService): flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2) if self.energy_expended is not None: - flags |= (1 << 3) + flags |= 1 << 3 data += struct.pack('> 1) - value = data[3:3 + length] + value = data[3 : 3 + length] return (type, c_r, value) @staticmethod def make_mcc(type, c_r, data): - return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data + return ( + bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + + data + ) @staticmethod def sabm(c_r, dlci): @@ -169,8 +176,10 @@ class RFCOMM_Frame: return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1) @staticmethod - def uih(c_r, dlci, information, p_f = 0): - return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1)) + def uih(c_r, dlci, information, p_f=0): + return RFCOMM_Frame( + RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1) + ) @staticmethod def from_bytes(data): @@ -197,7 +206,12 @@ class RFCOMM_Frame: return frame def __bytes__(self): - return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs]) + return ( + bytes([self.address, self.control]) + + self.length + + self.information + + bytes([self.fcs]) + ) def __str__(self): return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})' @@ -205,38 +219,49 @@ class RFCOMM_Frame: # ----------------------------------------------------------------------------- class RFCOMM_MCC_PN: - def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size): - self.dlci = dlci - self.cl = cl - self.priority = priority - self.ack_timer = ack_timer - self.max_frame_size = max_frame_size + def __init__( + self, + dlci, + cl, + priority, + ack_timer, + max_frame_size, + max_retransmissions, + window_size, + ): + self.dlci = dlci + self.cl = cl + self.priority = priority + self.ack_timer = ack_timer + self.max_frame_size = max_frame_size self.max_retransmissions = max_retransmissions - self.window_size = window_size + self.window_size = window_size @staticmethod def from_bytes(data): return RFCOMM_MCC_PN( - dlci = data[0], - cl = data[1], - priority = data[2], - ack_timer = data[3], - max_frame_size = data[4] | data[5] << 8, - max_retransmissions = data[6], - window_size = data[7] + dlci=data[0], + cl=data[1], + priority=data[2], + ack_timer=data[3], + max_frame_size=data[4] | data[5] << 8, + max_retransmissions=data[6], + window_size=data[7], ) def __bytes__(self): - return bytes([ - self.dlci & 0xFF, - self.cl & 0xFF, - self.priority & 0xFF, - self.ack_timer & 0xFF, - self.max_frame_size & 0xFF, - (self.max_frame_size >> 8) & 0xFF, - self.max_retransmissions & 0xFF, - self.window_size & 0xFF - ]) + return bytes( + [ + self.dlci & 0xFF, + self.cl & 0xFF, + self.priority & 0xFF, + self.ack_timer & 0xFF, + self.max_frame_size & 0xFF, + (self.max_frame_size >> 8) & 0xFF, + self.max_retransmissions & 0xFF, + self.window_size & 0xFF, + ] + ) def __str__(self): return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})' @@ -246,28 +271,35 @@ class RFCOMM_MCC_PN: class RFCOMM_MCC_MSC: def __init__(self, dlci, fc, rtc, rtr, ic, dv): self.dlci = dlci - self.fc = fc - self.rtc = rtc - self.rtr = rtr - self.ic = ic - self.dv = dv + self.fc = fc + self.rtc = rtc + self.rtr = rtr + self.ic = ic + self.dv = dv @staticmethod def from_bytes(data): return RFCOMM_MCC_MSC( - dlci = data[0] >> 2, - fc = data[1] >> 1 & 1, - rtc = data[1] >> 2 & 1, - rtr = data[1] >> 3 & 1, - ic = data[1] >> 6 & 1, - dv = data[1] >> 7 & 1 + dlci=data[0] >> 2, + fc=data[1] >> 1 & 1, + rtc=data[1] >> 2 & 1, + rtr=data[1] >> 3 & 1, + ic=data[1] >> 6 & 1, + dv=data[1] >> 7 & 1, ) def __bytes__(self): - return bytes([ - (self.dlci << 2) | 3, - 1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7 - ]) + return bytes( + [ + (self.dlci << 2) | 3, + 1 + | self.fc << 1 + | self.rtc << 2 + | self.rtr << 3 + | self.ic << 6 + | self.dv << 7, + ] + ) def __str__(self): return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})' @@ -276,45 +308,49 @@ class RFCOMM_MCC_MSC: # ----------------------------------------------------------------------------- class DLC(EventEmitter): # States - INIT = 0x00 - CONNECTING = 0x01 - CONNECTED = 0x02 + INIT = 0x00 + CONNECTING = 0x01 + CONNECTED = 0x02 DISCONNECTING = 0x03 - DISCONNECTED = 0x04 - RESET = 0x05 + DISCONNECTED = 0x04 + RESET = 0x05 STATE_NAMES = { - INIT: 'INIT', - CONNECTING: 'CONNECTING', - CONNECTED: 'CONNECTED', + INIT: 'INIT', + CONNECTING: 'CONNECTING', + CONNECTED: 'CONNECTED', DISCONNECTING: 'DISCONNECTING', - DISCONNECTED: 'DISCONNECTED', - RESET: 'RESET' + DISCONNECTED: 'DISCONNECTED', + RESET: 'RESET', } def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits): super().__init__() - self.multiplexer = multiplexer - self.dlci = dlci - self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS + self.multiplexer = multiplexer + self.dlci = dlci + self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS self.rx_threshold = self.rx_credits // 2 - self.tx_credits = initial_tx_credits - self.tx_buffer = b'' - self.state = DLC.INIT - self.role = multiplexer.role - self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0 - self.sink = None + self.tx_credits = initial_tx_credits + self.tx_buffer = b'' + self.state = DLC.INIT + self.role = multiplexer.role + self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0 + self.sink = None # Compute the MTU max_overhead = 4 + 1 # header with 2-byte length + fcs - self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead) + self.mtu = min( + max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead + ) @staticmethod def state_name(state): return DLC.STATE_NAMES[state] def change_state(self, new_state): - logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}') + logger.debug( + f'{self} state change -> {color(self.state_name(new_state), "magenta")}' + ) self.state = new_state def send_frame(self, frame): @@ -329,26 +365,13 @@ class DLC(EventEmitter): logger.warn(color('!!! received SABM when not in CONNECTING state', 'red')) return - self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci)) + self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci)) # Exchange the modem status with the peer - msc = RFCOMM_MCC_MSC( - dlci = self.dlci, - fc = 0, - rtc = 1, - rtr = 1, - ic = 0, - dv = 1 - ) - mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc)) + msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1) + mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)) logger.debug(f'>>> MCC MSC Command: {msc}') - self.send_frame( - RFCOMM_Frame.uih( - c_r = self.c_r, - dlci = 0, - information = mcc - ) - ) + self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.change_state(DLC.CONNECTED) self.emit('open') @@ -359,23 +382,10 @@ class DLC(EventEmitter): return # Exchange the modem status with the peer - msc = RFCOMM_MCC_MSC( - dlci = self.dlci, - fc = 0, - rtc = 1, - rtr = 1, - ic = 0, - dv = 1 - ) - mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc)) + msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1) + mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)) logger.debug(f'>>> MCC MSC Command: {msc}') - self.send_frame( - RFCOMM_Frame.uih( - c_r = self.c_r, - dlci = 0, - information = mcc - ) - ) + self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.change_state(DLC.CONNECTED) self.multiplexer.on_dlc_open_complete(self) @@ -386,7 +396,7 @@ class DLC(EventEmitter): def on_disc_frame(self, frame): # TODO: handle all states - self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci)) + self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci)) def on_uih_frame(self, frame): data = frame.information @@ -395,10 +405,14 @@ class DLC(EventEmitter): credits = frame.information[0] self.tx_credits += credits - logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}') + logger.debug( + f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}' + ) data = data[1:] - logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}') + logger.debug( + f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}' + ) if len(data) and self.sink: self.sink(data) @@ -418,23 +432,12 @@ class DLC(EventEmitter): if c_r: # Command logger.debug(f'<<< MCC MSC Command: {msc}') - msc = RFCOMM_MCC_MSC( - dlci = self.dlci, - fc = 0, - rtc = 1, - rtr = 1, - ic = 0, - dv = 1 + msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1) + mcc = RFCOMM_Frame.make_mcc( + type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc) ) - mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc)) logger.debug(f'>>> MCC MSC Response: {msc}') - self.send_frame( - RFCOMM_Frame.uih( - c_r = self.c_r, - dlci = 0, - information = mcc - ) - ) + self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) else: # Response logger.debug(f'<<< MCC MSC Response: {msc}') @@ -445,35 +448,24 @@ class DLC(EventEmitter): self.change_state(DLC.CONNECTING) self.connection_result = asyncio.get_running_loop().create_future() - self.send_frame( - RFCOMM_Frame.sabm( - c_r = self.c_r, - dlci = self.dlci - ) - ) + self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci)) def accept(self): if not self.state == DLC.INIT: raise InvalidStateError('invalid state') pn = RFCOMM_MCC_PN( - dlci = self.dlci, - cl = 0xE0, - priority = 7, - ack_timer = 0, - max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU, - max_retransmissions = 0, - window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS + dlci=self.dlci, + cl=0xE0, + priority=7, + ack_timer=0, + max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, + max_retransmissions=0, + window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS, ) - mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 0, data = bytes(pn)) + mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn)) logger.debug(f'>>> PN Response: {pn}') - self.send_frame( - RFCOMM_Frame.uih( - c_r = self.c_r, - dlci = 0, - information = mcc - ) - ) + self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.change_state(DLC.CONNECTING) def rx_credits_needed(self): @@ -488,13 +480,13 @@ class DLC(EventEmitter): while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0: # Get the next chunk, up to MTU size if rx_credits_needed > 0: - chunk = bytes([rx_credits_needed]) + self.tx_buffer[:self.mtu - 1] - self.tx_buffer = self.tx_buffer[len(chunk) - 1:] + chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1] + self.tx_buffer = self.tx_buffer[len(chunk) - 1 :] self.rx_credits += rx_credits_needed - tx_credit_spent = (len(chunk) > 1) + tx_credit_spent = len(chunk) > 1 else: - chunk = self.tx_buffer[:self.mtu] - self.tx_buffer = self.tx_buffer[len(chunk):] + chunk = self.tx_buffer[: self.mtu] + self.tx_buffer = self.tx_buffer[len(chunk) :] tx_credit_spent = True # Update the tx credits @@ -503,13 +495,15 @@ class DLC(EventEmitter): self.tx_credits -= 1 # Send the frame - logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}') + logger.debug( + f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}' + ) self.send_frame( RFCOMM_Frame.uih( - c_r = self.c_r, - dlci = self.dlci, - information = chunk, - p_f = 1 if rx_credits_needed > 0 else 0 + c_r=self.c_r, + dlci=self.dlci, + information=chunk, + p_f=1 if rx_credits_needed > 0 else 0, ) ) @@ -543,34 +537,34 @@ class Multiplexer(EventEmitter): RESPONDER = 0x01 # States - INIT = 0x00 - CONNECTING = 0x01 - CONNECTED = 0x02 - OPENING = 0x03 + INIT = 0x00 + CONNECTING = 0x01 + CONNECTED = 0x02 + OPENING = 0x03 DISCONNECTING = 0x04 - DISCONNECTED = 0x05 - RESET = 0x06 + DISCONNECTED = 0x05 + RESET = 0x06 STATE_NAMES = { - INIT: 'INIT', - CONNECTING: 'CONNECTING', - CONNECTED: 'CONNECTED', - OPENING: 'OPENING', + INIT: 'INIT', + CONNECTING: 'CONNECTING', + CONNECTED: 'CONNECTED', + OPENING: 'OPENING', DISCONNECTING: 'DISCONNECTING', - DISCONNECTED: 'DISCONNECTED', - RESET: 'RESET' + DISCONNECTED: 'DISCONNECTED', + RESET: 'RESET', } def __init__(self, l2cap_channel, role): super().__init__() - self.role = role - self.l2cap_channel = l2cap_channel - self.state = Multiplexer.INIT - self.dlcs = {} # DLCs, by DLCI - self.connection_result = None + self.role = role + self.l2cap_channel = l2cap_channel + self.state = Multiplexer.INIT + self.dlcs = {} # DLCs, by DLCI + self.connection_result = None self.disconnection_result = None - self.open_result = None - self.acceptor = None + self.open_result = None + self.acceptor = None # Become a sink for the L2CAP channel l2cap_channel.sink = self.on_pdu @@ -580,7 +574,9 @@ class Multiplexer(EventEmitter): return Multiplexer.STATE_NAMES[state] def change_state(self, new_state): - logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}') + logger.debug( + f'{self} state change -> {color(self.state_name(new_state), "cyan")}' + ) self.state = new_state def send_frame(self, frame): @@ -616,7 +612,7 @@ class Multiplexer(EventEmitter): logger.debug('not in INIT state, ignoring SABM') return self.change_state(Multiplexer.CONNECTED) - self.send_frame(RFCOMM_Frame.ua(c_r = 1, dlci = 0)) + self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0)) def on_ua_frame(self, frame): if self.state == Multiplexer.CONNECTING: @@ -634,18 +630,22 @@ class Multiplexer(EventEmitter): if self.state == Multiplexer.OPENING: self.change_state(Multiplexer.CONNECTED) if self.open_result: - self.open_result.set_exception(ConnectionError( - ConnectionError.CONNECTION_REFUSED, - BT_BR_EDR_TRANSPORT, - self.l2cap_channel.connection.peer_address, - 'rfcomm' - )) + self.open_result.set_exception( + ConnectionError( + ConnectionError.CONNECTION_REFUSED, + BT_BR_EDR_TRANSPORT, + self.l2cap_channel.connection.peer_address, + 'rfcomm', + ) + ) else: logger.warn(f'unexpected state for DM: {self}') def on_disc_frame(self, frame): self.change_state(Multiplexer.DISCONNECTED) - self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0)) + self.send_frame( + RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0) + ) def on_uih_frame(self, frame): (type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) @@ -685,7 +685,7 @@ class Multiplexer(EventEmitter): dlc.accept() else: # No acceptor, we're in Disconnected Mode - self.send_frame(RFCOMM_Frame.dm(c_r = 1, dlci = pn.dlci)) + self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci)) else: # No acceptor?? shouldn't happen logger.warn(color('!!! no acceptor registered', 'red')) @@ -712,7 +712,7 @@ class Multiplexer(EventEmitter): self.change_state(Multiplexer.CONNECTING) self.connection_result = asyncio.get_running_loop().create_future() - self.send_frame(RFCOMM_Frame.sabm(c_r = 1, dlci = 0)) + self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0)) return await self.connection_result async def disconnect(self): @@ -721,7 +721,11 @@ class Multiplexer(EventEmitter): self.disconnection_result = asyncio.get_running_loop().create_future() self.change_state(Multiplexer.DISCONNECTING) - self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0)) + self.send_frame( + RFCOMM_Frame.disc( + c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0 + ) + ) await self.disconnection_result async def open_dlc(self, channel): @@ -732,23 +736,23 @@ class Multiplexer(EventEmitter): raise InvalidStateError('not connected') pn = RFCOMM_MCC_PN( - dlci = channel << 1, - cl = 0xF0, - priority = 7, - ack_timer = 0, - max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU, - max_retransmissions = 0, - window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS + dlci=channel << 1, + cl=0xF0, + priority=7, + ack_timer=0, + max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, + max_retransmissions=0, + window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS, ) - mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 1, data = bytes(pn)) + mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn)) logger.debug(f'>>> Sending MCC: {pn}') self.open_result = asyncio.get_running_loop().create_future() self.change_state(Multiplexer.OPENING) self.send_frame( RFCOMM_Frame.uih( - c_r = 1 if self.role == Multiplexer.INITIATOR else 0, - dlci = 0, - information = mcc + c_r=1 if self.role == Multiplexer.INITIATOR else 0, + dlci=0, + information=mcc, ) ) result = await self.open_result @@ -768,15 +772,17 @@ class Multiplexer(EventEmitter): # ----------------------------------------------------------------------------- class Client: def __init__(self, device, connection): - self.device = device - self.connection = connection + self.device = device + self.connection = connection self.l2cap_channel = None - self.multiplexer = None + self.multiplexer = None async def start(self): # Create a new L2CAP connection try: - self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM) + self.l2cap_channel = await self.device.l2cap_channel_manager.connect( + self.connection, RFCOMM_PSM + ) except ProtocolError as error: logger.warn(f'L2CAP connection failed: {error}') raise @@ -802,16 +808,18 @@ class Client: class Server(EventEmitter): def __init__(self, device): super().__init__() - self.device = device + self.device = device self.multiplexer = None - self.acceptors = {} + self.acceptors = {} # Register ourselves with the L2CAP channel manager device.register_l2cap_server(RFCOMM_PSM, self.on_connection) def listen(self, acceptor): # Find a free channel number - for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1): + for channel in range( + RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1 + ): if channel not in self.acceptors: self.acceptors[channel] = acceptor return channel diff --git a/bumble/sdp.py b/bumble/sdp.py index 935561e4..460b8e9e 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do SDP_PSM = 0x0001 @@ -112,48 +114,64 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') # To be used in searches where an attribute ID list allows a range to be specified SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size +# fmt: on + # ----------------------------------------------------------------------------- class DataElement: - NIL = 0 + NIL = 0 UNSIGNED_INTEGER = 1 - SIGNED_INTEGER = 2 - UUID = 3 - TEXT_STRING = 4 - BOOLEAN = 5 - SEQUENCE = 6 - ALTERNATIVE = 7 - URL = 8 + SIGNED_INTEGER = 2 + UUID = 3 + TEXT_STRING = 4 + BOOLEAN = 5 + SEQUENCE = 6 + ALTERNATIVE = 7 + URL = 8 TYPE_NAMES = { - NIL: 'NIL', + NIL: 'NIL', UNSIGNED_INTEGER: 'UNSIGNED_INTEGER', - SIGNED_INTEGER: 'SIGNED_INTEGER', - UUID: 'UUID', - TEXT_STRING: 'TEXT_STRING', - BOOLEAN: 'BOOLEAN', - SEQUENCE: 'SEQUENCE', - ALTERNATIVE: 'ALTERNATIVE', - URL: 'URL' + SIGNED_INTEGER: 'SIGNED_INTEGER', + UUID: 'UUID', + TEXT_STRING: 'TEXT_STRING', + BOOLEAN: 'BOOLEAN', + SEQUENCE: 'SEQUENCE', + ALTERNATIVE: 'ALTERNATIVE', + URL: 'URL', } type_constructors = { - NIL: lambda x: DataElement(DataElement.NIL, None), - UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y), - SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y), - UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))), - TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')), - BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), - SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)), - ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)), - URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')) + NIL: lambda x: DataElement(DataElement.NIL, None), + UNSIGNED_INTEGER: lambda x, y: DataElement( + DataElement.UNSIGNED_INTEGER, + DataElement.unsigned_integer_from_bytes(x), + value_size=y, + ), + SIGNED_INTEGER: lambda x, y: DataElement( + DataElement.SIGNED_INTEGER, + DataElement.signed_integer_from_bytes(x), + value_size=y, + ), + UUID: lambda x: DataElement( + DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x))) + ), + TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')), + BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), + SEQUENCE: lambda x: DataElement( + DataElement.SEQUENCE, DataElement.list_from_bytes(x) + ), + ALTERNATIVE: lambda x: DataElement( + DataElement.ALTERNATIVE, DataElement.list_from_bytes(x) + ), + URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')), } def __init__(self, type, value, value_size=None): - self.type = type - self.value = value + self.type = type + self.value = value self.value_size = value_size - self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica + self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER: if value_size is None: raise ValueError('integer types must have a value size specified') @@ -250,7 +268,7 @@ class DataElement: while data: element = DataElement.from_bytes(data) elements.append(element) - data = data[len(bytes(element)):] + data = data[len(bytes(element)) :] return elements @staticmethod @@ -261,7 +279,7 @@ class DataElement: @staticmethod def from_bytes(data): type = data[0] >> 3 - size_index = data[0] & 7 + size_index = data[0] & 7 value_offset = 0 if size_index == 0: if type == DataElement.NIL: @@ -286,16 +304,21 @@ class DataElement: value_size = struct.unpack('>I', data[1:5])[0] value_offset = 4 - value_data = data[1 + value_offset:1 + value_offset + value_size] + value_data = data[1 + value_offset : 1 + value_offset + value_size] constructor = DataElement.type_constructors.get(type) if constructor: - if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER: + if ( + type == DataElement.UNSIGNED_INTEGER + or type == DataElement.SIGNED_INTEGER + ): result = constructor(value_data, value_size) else: result = constructor(value_data) else: result = DataElement(type, value_data) - result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica + result.bytes = data[ + : 1 + value_offset + value_size + ] # Keep a copy so we can re-serialize to an exact replica return result def to_bytes(self): @@ -349,9 +372,11 @@ class DataElement: if size != 0: raise ValueError('NIL must be empty') size_index = 0 - elif (self.type == DataElement.UNSIGNED_INTEGER or - self.type == DataElement.SIGNED_INTEGER or - self.type == DataElement.UUID): + elif ( + self.type == DataElement.UNSIGNED_INTEGER + or self.type == DataElement.SIGNED_INTEGER + or self.type == DataElement.UUID + ): if size <= 1: size_index = 0 elif size == 2: @@ -364,10 +389,12 @@ class DataElement: size_index = 4 else: raise ValueError('invalid data size') - elif (self.type == DataElement.TEXT_STRING or - self.type == DataElement.SEQUENCE or - self.type == DataElement.ALTERNATIVE or - self.type == DataElement.URL): + elif ( + self.type == DataElement.TEXT_STRING + or self.type == DataElement.SEQUENCE + or self.type == DataElement.ALTERNATIVE + or self.type == DataElement.URL + ): if size <= 0xFF: size_index = 5 size_bytes = bytes([size]) @@ -396,7 +423,10 @@ class DataElement: container_separator = '\n' if pretty else '' element_separator = '\n' if pretty else ',' value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]' - elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER: + elif ( + self.type == DataElement.UNSIGNED_INTEGER + or self.type == DataElement.SIGNED_INTEGER + ): value_string = f'{self.value}#{self.value_size}' elif isinstance(self.value, DataElement): value_string = self.value.to_string(pretty, indentation) @@ -411,14 +441,14 @@ class DataElement: # ----------------------------------------------------------------------------- class ServiceAttribute: def __init__(self, id, value): - self.id = id + self.id = id self.value = value @staticmethod def list_from_data_elements(elements): attribute_list = [] for i in range(0, len(elements) // 2): - attribute_id, attribute_value = elements[2 * i:2 * (i + 1)] + attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] if attribute_id.type != DataElement.UNSIGNED_INTEGER: logger.warn('attribute ID element is not an integer') continue @@ -428,7 +458,14 @@ class ServiceAttribute: @staticmethod def find_attribute_in_list(attribute_list, attribute_id): - return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None) + return next( + ( + attribute.value + for attribute in attribute_list + if attribute.id == attribute_id + ), + None, + ) @staticmethod def id_name(id): @@ -462,6 +499,7 @@ class SDP_PDU: ''' See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT ''' + sdp_pdu_classes = {} @staticmethod @@ -484,13 +522,15 @@ class SDP_PDU: @staticmethod def parse_service_record_handle_list_preceded_by_count(data, offset): count = struct.unpack_from('>H', data, offset - 2)[0] - handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)] + handle_list = [ + struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) + ] return offset + count * 4, handle_list @staticmethod def parse_bytes_preceded_by_length(data, offset): length = struct.unpack_from('>H', data, offset - 2)[0] - return offset + length, data[offset:offset + length] + return offset + length, data[offset : offset + length] @staticmethod def error_name(error_code): @@ -532,7 +572,10 @@ class SDP_PDU: HCI_Object.init_from_fields(self, self.fields, kwargs) if pdu is None: parameters = HCI_Object.dict_to_bytes(kwargs, self.fields) - pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters + pdu = ( + struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + + parameters + ) self.pdu = pdu self.transaction_id = transaction_id @@ -555,9 +598,7 @@ class SDP_PDU: # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('error_code', {'size': 2, 'mapper': SDP_PDU.error_name}) -]) +@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})]) class SDP_ErrorResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU @@ -565,11 +606,13 @@ class SDP_ErrorResponse(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('service_search_pattern', DataElement.parse_from_bytes), - ('maximum_service_record_count', '>2'), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('service_search_pattern', DataElement.parse_from_bytes), + ('maximum_service_record_count', '>2'), + ('continuation_state', '*'), + ] +) class SDP_ServiceSearchRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU @@ -577,12 +620,17 @@ class SDP_ServiceSearchRequest(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('total_service_record_count', '>2'), - ('current_service_record_count', '>2'), - ('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('total_service_record_count', '>2'), + ('current_service_record_count', '>2'), + ( + 'service_record_handle_list', + SDP_PDU.parse_service_record_handle_list_preceded_by_count, + ), + ('continuation_state', '*'), + ] +) class SDP_ServiceSearchResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU @@ -590,12 +638,14 @@ class SDP_ServiceSearchResponse(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('service_record_handle', '>4'), - ('maximum_attribute_byte_count', '>2'), - ('attribute_id_list', DataElement.parse_from_bytes), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('service_record_handle', '>4'), + ('maximum_attribute_byte_count', '>2'), + ('attribute_id_list', DataElement.parse_from_bytes), + ('continuation_state', '*'), + ] +) class SDP_ServiceAttributeRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU @@ -603,11 +653,13 @@ class SDP_ServiceAttributeRequest(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('attribute_list_byte_count', '>2'), - ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('attribute_list_byte_count', '>2'), + ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length), + ('continuation_state', '*'), + ] +) class SDP_ServiceAttributeResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU @@ -615,12 +667,14 @@ class SDP_ServiceAttributeResponse(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('service_search_pattern', DataElement.parse_from_bytes), - ('maximum_attribute_byte_count', '>2'), - ('attribute_id_list', DataElement.parse_from_bytes), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('service_search_pattern', DataElement.parse_from_bytes), + ('maximum_attribute_byte_count', '>2'), + ('attribute_id_list', DataElement.parse_from_bytes), + ('continuation_state', '*'), + ] +) class SDP_ServiceSearchAttributeRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU @@ -628,11 +682,13 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU): # ----------------------------------------------------------------------------- -@SDP_PDU.subclass([ - ('attribute_lists_byte_count', '>2'), - ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length), - ('continuation_state', '*') -]) +@SDP_PDU.subclass( + [ + ('attribute_lists_byte_count', '>2'), + ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length), + ('continuation_state', '*'), + ] +) class SDP_ServiceSearchAttributeResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU @@ -642,9 +698,9 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU): # ----------------------------------------------------------------------------- class Client: def __init__(self, device): - self.device = device + self.device = device self.pending_request = None - self.channel = None + self.channel = None async def connect(self, connection): result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM) @@ -659,7 +715,9 @@ class Client: if self.pending_request is not None: raise InvalidStateError('request already pending') - service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) + service_search_pattern = DataElement.sequence( + [DataElement.uuid(uuid) for uuid in uuids] + ) # Request and accumulate until there's no more continuation service_record_handle_list = [] @@ -668,10 +726,10 @@ class Client: while watchdog > 0: response_pdu = await self.channel.send_request( SDP_ServiceSearchRequest( - transaction_id = 0, # Transaction ID TODO: pick a real value - service_search_pattern = service_search_pattern, - maximum_service_record_count = 0xFFFF, - continuation_state = continuation_state + transaction_id=0, # Transaction ID TODO: pick a real value + service_search_pattern=service_search_pattern, + maximum_service_record_count=0xFFFF, + continuation_state=continuation_state, ) ) response = SDP_PDU.from_bytes(response_pdu) @@ -689,10 +747,14 @@ class Client: if self.pending_request is not None: raise InvalidStateError('request already pending') - service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) + service_search_pattern = DataElement.sequence( + [DataElement.uuid(uuid) for uuid in uuids] + ) attribute_id_list = DataElement.sequence( [ - DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) + DataElement.unsigned_integer( + attribute_id[0], value_size=attribute_id[1] + ) if type(attribute_id) is tuple else DataElement.unsigned_integer_16(attribute_id) for attribute_id in attribute_ids @@ -706,11 +768,11 @@ class Client: while watchdog > 0: response_pdu = await self.channel.send_request( SDP_ServiceSearchAttributeRequest( - transaction_id = 0, # Transaction ID TODO: pick a real value - service_search_pattern = service_search_pattern, - maximum_attribute_byte_count = 0xFFFF, - attribute_id_list = attribute_id_list, - continuation_state = continuation_state + transaction_id=0, # Transaction ID TODO: pick a real value + service_search_pattern=service_search_pattern, + maximum_attribute_byte_count=0xFFFF, + attribute_id_list=attribute_id_list, + continuation_state=continuation_state, ) ) response = SDP_PDU.from_bytes(response_pdu) @@ -740,7 +802,9 @@ class Client: attribute_id_list = DataElement.sequence( [ - DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) + DataElement.unsigned_integer( + attribute_id[0], value_size=attribute_id[1] + ) if type(attribute_id) is tuple else DataElement.unsigned_integer_16(attribute_id) for attribute_id in attribute_ids @@ -754,11 +818,11 @@ class Client: while watchdog > 0: response_pdu = await self.channel.send_request( SDP_ServiceAttributeRequest( - transaction_id = 0, # Transaction ID TODO: pick a real value - service_record_handle = service_record_handle, - maximum_attribute_byte_count = 0xFFFF, - attribute_id_list = attribute_id_list, - continuation_state = continuation_state + transaction_id=0, # Transaction ID TODO: pick a real value + service_record_handle=service_record_handle, + maximum_attribute_byte_count=0xFFFF, + attribute_id_list=attribute_id_list, + continuation_state=continuation_state, ) ) response = SDP_PDU.from_bytes(response_pdu) @@ -784,8 +848,8 @@ class Server: CONTINUATION_STATE = bytes([0x01, 0x43]) def __init__(self, device): - self.device = device - self.service_records = {} # Service records maps, by record handle + self.device = device + self.service_records = {} # Service records maps, by record handle self.current_response = None def register(self, l2cap_channel_manager): @@ -823,8 +887,7 @@ class Server: logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red')) self.send_response( SDP_ErrorResponse( - transaction_id = 0, - error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR + transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR ) ) @@ -840,16 +903,16 @@ class Server: logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') self.send_response( SDP_ErrorResponse( - transaction_id = sdp_pdu.transaction_id, - error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR + transaction_id=sdp_pdu.transaction_id, + error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR, ) ) else: logger.error(color('SDP Request not handled???', 'red')) self.send_response( SDP_ErrorResponse( - transaction_id = sdp_pdu.transaction_id, - error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR + transaction_id=sdp_pdu.transaction_id, + error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR, ) ) @@ -872,17 +935,18 @@ class Server: if attribute_id.value_size == 4: # Attribute ID range id_range_start = attribute_id.value >> 16 - id_range_end = attribute_id.value & 0xFFFF + id_range_end = attribute_id.value & 0xFFFF else: id_range_start = attribute_id.value - id_range_end = attribute_id.value + id_range_end = attribute_id.value attributes += [ - attribute for attribute in service + attribute + for attribute in service if attribute.id >= id_range_start and attribute.id <= id_range_end ] # Return the maching attributes, sorted by attribute id - attributes.sort(key = lambda x: x.id) + attributes.sort(key=lambda x: x.id) attribute_list = DataElement.sequence([]) for attribute in attributes: attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id)) @@ -896,8 +960,8 @@ class Server: if not self.current_response: self.send_response( SDP_ErrorResponse( - transaction_id = request.transaction_id, - error_code = SDP_INVALID_CONTINUATION_STATE_ERROR + transaction_id=request.transaction_id, + error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, ) ) return @@ -910,30 +974,38 @@ class Server: service_record_handles = list(matching_services.keys()) # Only return up to the maximum requested - service_record_handles_subset = service_record_handles[:request.maximum_service_record_count] + service_record_handles_subset = service_record_handles[ + : request.maximum_service_record_count + ] # Serialize to a byte array, and remember the total count logger.debug(f'Service Record Handles: {service_record_handles}') self.current_response = ( len(service_record_handles), - service_record_handles_subset + service_record_handles_subset, ) # Respond, keeping any unsent handles for later - service_record_handles = self.current_response[1][:request.maximum_service_record_count] + service_record_handles = self.current_response[1][ + : request.maximum_service_record_count + ] self.current_response = ( self.current_response[0], - self.current_response[1][request.maximum_service_record_count:] + self.current_response[1][request.maximum_service_record_count :], + ) + continuation_state = ( + Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) + ) + service_record_handle_list = b''.join( + [struct.pack('>I', handle) for handle in service_record_handles] ) - continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) - service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles]) self.send_response( SDP_ServiceSearchResponse( - transaction_id = request.transaction_id, - total_service_record_count = self.current_response[0], - current_service_record_count = len(service_record_handles), - service_record_handle_list = service_record_handle_list, - continuation_state = continuation_state + transaction_id=request.transaction_id, + total_service_record_count=self.current_response[0], + current_service_record_count=len(service_record_handles), + service_record_handle_list=service_record_handle_list, + continuation_state=continuation_state, ) ) @@ -943,8 +1015,8 @@ class Server: if not self.current_response: self.send_response( SDP_ErrorResponse( - transaction_id = request.transaction_id, - error_code = SDP_INVALID_CONTINUATION_STATE_ERROR + transaction_id=request.transaction_id, + error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, ) ) return @@ -957,27 +1029,31 @@ class Server: if service is None: self.send_response( SDP_ErrorResponse( - transaction_id = request.transaction_id, - error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR + transaction_id=request.transaction_id, + error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR, ) ) return # Get the attributes for the service - attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) + attribute_list = Server.get_service_attributes( + service, request.attribute_id_list.value + ) # Serialize to a byte array logger.debug(f'Attributes: {attribute_list}') self.current_response = bytes(attribute_list) # Respond, keeping any pending chunks for later - attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) + attribute_list, continuation_state = self.get_next_response_payload( + request.maximum_attribute_byte_count + ) self.send_response( SDP_ServiceAttributeResponse( - transaction_id = request.transaction_id, - attribute_list_byte_count = len(attribute_list), - attribute_list = attribute_list, - continuation_state = continuation_state + transaction_id=request.transaction_id, + attribute_list_byte_count=len(attribute_list), + attribute_list=attribute_list, + continuation_state=continuation_state, ) ) @@ -987,8 +1063,8 @@ class Server: if not self.current_response: self.send_response( SDP_ErrorResponse( - transaction_id = request.transaction_id, - error_code = SDP_INVALID_CONTINUATION_STATE_ERROR + transaction_id=request.transaction_id, + error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, ) ) else: @@ -996,12 +1072,16 @@ class Server: self.current_response = None # Find the matching services - matching_services = self.match_services(request.service_search_pattern).values() + matching_services = self.match_services( + request.service_search_pattern + ).values() # Filter the required attributes attribute_lists = DataElement.sequence([]) for service in matching_services: - attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) + attribute_list = Server.get_service_attributes( + service, request.attribute_id_list.value + ) if attribute_list.value: attribute_lists.value.append(attribute_list) @@ -1010,12 +1090,14 @@ class Server: self.current_response = bytes(attribute_lists) # Respond, keeping any pending chunks for later - attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) + attribute_lists, continuation_state = self.get_next_response_payload( + request.maximum_attribute_byte_count + ) self.send_response( SDP_ServiceSearchAttributeResponse( - transaction_id = request.transaction_id, - attribute_lists_byte_count = len(attribute_lists), - attribute_lists = attribute_lists, - continuation_state = continuation_state + transaction_id=request.transaction_id, + attribute_lists_byte_count=len(attribute_lists), + attribute_lists=attribute_lists, + continuation_state=continuation_state, ) ) diff --git a/bumble/smp.py b/bumble/smp.py index 4c6ca4eb..e9d6fe33 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -43,6 +43,8 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +# fmt: off + SMP_CID = 0x06 SMP_BR_CID = 0x07 @@ -155,6 +157,8 @@ SMP_CT2_AUTHREQ = 0b00100000 SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031') SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032') +# fmt: on + # ----------------------------------------------------------------------------- # Utils @@ -170,6 +174,7 @@ class SMP_Command: ''' See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL ''' + smp_classes = {} code = 0 @@ -196,10 +201,10 @@ class SMP_Command: @staticmethod def auth_req_str(value): bonding_flags = value & 3 - mitm = (value >> 2) & 1 - sc = (value >> 3) & 1 - keypress = (value >> 4) & 1 - ct2 = (value >> 5) & 1 + mitm = (value >> 2) & 1 + sc = (value >> 3) & 1 + keypress = (value >> 4) & 1 + ct2 = (value >> 5) & 1 return f'bonding_flags={bonding_flags}, MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}' @@ -230,7 +235,9 @@ class SMP_Command: cls.name = cls.__name__.upper() cls.code = key_with_value(SMP_COMMAND_NAMES, cls.name) if cls.code is None: - raise KeyError(f'Command name {cls.name} not found in SMP_COMMAND_NAMES') + raise KeyError( + f'Command name {cls.name} not found in SMP_COMMAND_NAMES' + ) cls.fields = fields # Register a factory for this class @@ -267,14 +274,22 @@ class SMP_Command: # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('io_capability', {'size': 1, 'mapper': SMP_Command.io_capability_name}), - ('oob_data_flag', 1), - ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), - ('maximum_encryption_key_size', 1), - ('initiator_key_distribution', {'size': 1, 'mapper': SMP_Command.key_distribution_str}), - ('responder_key_distribution', {'size': 1, 'mapper': SMP_Command.key_distribution_str}) -]) +@SMP_Command.subclass( + [ + ('io_capability', {'size': 1, 'mapper': SMP_Command.io_capability_name}), + ('oob_data_flag', 1), + ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), + ('maximum_encryption_key_size', 1), + ( + 'initiator_key_distribution', + {'size': 1, 'mapper': SMP_Command.key_distribution_str}, + ), + ( + 'responder_key_distribution', + {'size': 1, 'mapper': SMP_Command.key_distribution_str}, + ), + ] +) class SMP_Pairing_Request_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request @@ -282,14 +297,22 @@ class SMP_Pairing_Request_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('io_capability', {'size': 1, 'mapper': SMP_Command.io_capability_name}), - ('oob_data_flag', 1), - ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), - ('maximum_encryption_key_size', 1), - ('initiator_key_distribution', {'size': 1, 'mapper': SMP_Command.key_distribution_str}), - ('responder_key_distribution', {'size': 1, 'mapper': SMP_Command.key_distribution_str}) -]) +@SMP_Command.subclass( + [ + ('io_capability', {'size': 1, 'mapper': SMP_Command.io_capability_name}), + ('oob_data_flag', 1), + ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), + ('maximum_encryption_key_size', 1), + ( + 'initiator_key_distribution', + {'size': 1, 'mapper': SMP_Command.key_distribution_str}, + ), + ( + 'responder_key_distribution', + {'size': 1, 'mapper': SMP_Command.key_distribution_str}, + ), + ] +) class SMP_Pairing_Response_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response @@ -297,9 +320,7 @@ class SMP_Pairing_Response_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('confirm_value', 16) -]) +@SMP_Command.subclass([('confirm_value', 16)]) class SMP_Pairing_Confirm_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm @@ -307,9 +328,7 @@ class SMP_Pairing_Confirm_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('random_value', 16) -]) +@SMP_Command.subclass([('random_value', 16)]) class SMP_Pairing_Random_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random @@ -317,9 +336,7 @@ class SMP_Pairing_Random_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('reason', {'size': 1, 'mapper': error_name}) -]) +@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})]) class SMP_Pairing_Failed_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed @@ -327,10 +344,7 @@ class SMP_Pairing_Failed_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('public_key_x', 32), - ('public_key_y', 32) -]) +@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)]) class SMP_Pairing_Public_Key_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key @@ -338,9 +352,11 @@ class SMP_Pairing_Public_Key_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('dhkey_check', 16), -]) +@SMP_Command.subclass( + [ + ('dhkey_check', 16), + ] +) class SMP_Pairing_DHKey_Check_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check @@ -348,9 +364,14 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('notification_type', {'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}), -]) +@SMP_Command.subclass( + [ + ( + 'notification_type', + {'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}, + ), + ] +) class SMP_Pairing_Keypress_Notification_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification @@ -358,9 +379,7 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('long_term_key', 16) -]) +@SMP_Command.subclass([('long_term_key', 16)]) class SMP_Encryption_Information_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information @@ -368,10 +387,7 @@ class SMP_Encryption_Information_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('ediv', 2), - ('rand', 8) -]) +@SMP_Command.subclass([('ediv', 2), ('rand', 8)]) class SMP_Master_Identification_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification @@ -379,9 +395,7 @@ class SMP_Master_Identification_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('identity_resolving_key', 16) -]) +@SMP_Command.subclass([('identity_resolving_key', 16)]) class SMP_Identity_Information_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information @@ -389,10 +403,12 @@ class SMP_Identity_Information_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('addr_type', Address.ADDRESS_TYPE_SPEC), - ('bd_addr', Address.parse_address_preceded_by_type) -]) +@SMP_Command.subclass( + [ + ('addr_type', Address.ADDRESS_TYPE_SPEC), + ('bd_addr', Address.parse_address_preceded_by_type), + ] +) class SMP_Identity_Address_Information_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information @@ -400,9 +416,7 @@ class SMP_Identity_Address_Information_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('signature_key', 16) -]) +@SMP_Command.subclass([('signature_key', 16)]) class SMP_Signing_Information_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information @@ -410,9 +424,11 @@ class SMP_Signing_Information_Command(SMP_Command): # ----------------------------------------------------------------------------- -@SMP_Command.subclass([ - ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), -]) +@SMP_Command.subclass( + [ + ('auth_req', {'size': 1, 'mapper': SMP_Command.auth_req_str}), + ] +) class SMP_Security_Request_Command(SMP_Command): ''' See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request @@ -452,23 +468,27 @@ class AddressResolver: resolved_address_type = Address.PUBLIC_IDENTITY_ADDRESS else: resolved_address_type = Address.RANDOM_IDENTITY_ADDRESS - return Address(address=str(resolved_address), address_type=resolved_address_type) + return Address( + address=str(resolved_address), address_type=resolved_address_type + ) # ----------------------------------------------------------------------------- class PairingDelegate: - NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY - KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY - DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY - DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY + NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY + KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY + DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY + DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY - DEFAULT_KEY_DISTRIBUTION = (SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG) + DEFAULT_KEY_DISTRIBUTION = ( + SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG + ) def __init__( self, io_capability=NO_OUTPUT_NO_INPUT, local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION, - local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION + local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION, ): self.io_capability = io_capability self.local_initiator_key_distribution = local_initiator_key_distribution @@ -489,21 +509,21 @@ class PairingDelegate: async def display_number(self, number, digits=6): pass - async def key_distribution_response(self, peer_initiator_key_distribution, peer_responder_key_distribution): + async def key_distribution_response( + self, peer_initiator_key_distribution, peer_responder_key_distribution + ): return ( - (peer_initiator_key_distribution & - self.local_initiator_key_distribution), - (peer_responder_key_distribution & - self.local_responder_key_distribution) + (peer_initiator_key_distribution & self.local_initiator_key_distribution), + (peer_responder_key_distribution & self.local_responder_key_distribution), ) # ----------------------------------------------------------------------------- class PairingConfig: def __init__(self, sc=True, mitm=True, bonding=True, delegate=None): - self.sc = sc - self.mitm = mitm - self.bonding = bonding + self.sc = sc + self.mitm = mitm + self.bonding = bonding self.delegate = delegate or PairingDelegate() def __str__(self): @@ -514,16 +534,16 @@ class PairingConfig: # ----------------------------------------------------------------------------- class Session: # Pairing methods - JUST_WORKS = 0 + JUST_WORKS = 0 NUMERIC_COMPARISON = 1 - PASSKEY = 2 - OOB = 3 + PASSKEY = 2 + OOB = 3 PAIRING_METHOD_NAMES = { - JUST_WORKS: 'JUST_WORKS', + JUST_WORKS: 'JUST_WORKS', NUMERIC_COMPARISON: 'NUMERIC_COMPARISON', - PASSKEY: 'PASSKEY', - OOB: 'OOB' + PASSKEY: 'PASSKEY', + OOB: 'OOB', } # I/O Capability to pairing method decision matrix @@ -538,82 +558,96 @@ class Session: # to specify if the initiator and responder should display (True) or input a code (False). PAIRING_METHODS = { SMP_DISPLAY_ONLY_IO_CAPABILITY: { - SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, - SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), + SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, + SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS, + SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, True, False), + SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, True, False), }, SMP_DISPLAY_YES_NO_IO_CAPABILITY: { - SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, - SMP_DISPLAY_YES_NO_IO_CAPABILITY: (JUST_WORKS, NUMERIC_COMPARISON), - SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), + SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, + SMP_DISPLAY_YES_NO_IO_CAPABILITY: (JUST_WORKS, NUMERIC_COMPARISON), + SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: ((PASSKEY, True, False), NUMERIC_COMPARISON) + SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: ( + (PASSKEY, True, False), + NUMERIC_COMPARISON, + ), }, SMP_KEYBOARD_ONLY_IO_CAPABILITY: { - SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True), - SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PASSKEY, False, True), - SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, False, False), + SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True), + SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PASSKEY, False, True), + SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, False, False), SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, False, True), + SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, False, True), }, SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: { - SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, - SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_ONLY_IO_CAPABILITY: JUST_WORKS, + SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS, + SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS, + SMP_KEYBOARD_ONLY_IO_CAPABILITY: JUST_WORKS, SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: JUST_WORKS + SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: JUST_WORKS, }, SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: { - SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True), - SMP_DISPLAY_YES_NO_IO_CAPABILITY: ((PASSKEY, False, True), NUMERIC_COMPARISON), - SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), + SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True), + SMP_DISPLAY_YES_NO_IO_CAPABILITY: ( + (PASSKEY, False, True), + NUMERIC_COMPARISON, + ), + SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False), SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS, - SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: ((PASSKEY, True, False), NUMERIC_COMPARISON) - } + SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: ( + (PASSKEY, True, False), + NUMERIC_COMPARISON, + ), + }, } def __init__(self, manager, connection, pairing_config): - self.manager = manager - self.connection = connection - self.tk = bytes(16) - self.r = bytes(16) - self.stk = None - self.ltk = None - self.ltk_ediv = 0 - self.ltk_rand = bytes(8) - self.link_key = None - self.initiator_key_distribution = 0 - self.responder_key_distribution = 0 - self.peer_random_value = None - self.peer_public_key_x = bytes(32) - self.peer_public_key_y = bytes(32) - self.peer_ltk = None - self.peer_ediv = None - self.peer_rand = None + self.manager = manager + self.connection = connection + self.tk = bytes(16) + self.r = bytes(16) + self.stk = None + self.ltk = None + self.ltk_ediv = 0 + self.ltk_rand = bytes(8) + self.link_key = None + self.initiator_key_distribution = 0 + self.responder_key_distribution = 0 + self.peer_random_value = None + self.peer_public_key_x = bytes(32) + self.peer_public_key_y = bytes(32) + self.peer_ltk = None + self.peer_ediv = None + self.peer_rand = None self.peer_identity_resolving_key = None - self.peer_bd_addr = None - self.peer_signature_key = None + self.peer_bd_addr = None + self.peer_signature_key = None self.peer_expected_distributions = [] - self.dh_key = None - self.passkey = 0 - self.passkey_step = 0 - self.passkey_display = False - self.pairing_method = 0 - self.pairing_config = pairing_config - self.wait_before_continuing = None - self.completed = False - self.ctkd_task = None + self.dh_key = None + self.passkey = 0 + self.passkey_step = 0 + self.passkey_display = False + self.pairing_method = 0 + self.pairing_config = pairing_config + self.wait_before_continuing = None + self.completed = False + self.ctkd_task = None # Decide if we're the initiator or the responder - self.is_initiator = (connection.role == BT_CENTRAL_ROLE) + self.is_initiator = connection.role == BT_CENTRAL_ROLE self.is_responder = not self.is_initiator # Listen for connection events connection.on('disconnection', self.on_disconnection) - connection.on('connection_encryption_change', self.on_connection_encryption_change) - connection.on('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh) + connection.on( + 'connection_encryption_change', self.on_connection_encryption_change + ) + connection.on( + 'connection_encryption_key_refresh', + self.on_connection_encryption_key_refresh, + ) # Create a future that can be used to wait for the session to complete if self.is_initiator: @@ -622,18 +656,22 @@ class Session: self.pairing_result = None # Key Distribution (default values before negotiation) - self.initiator_key_distribution = pairing_config.delegate.local_initiator_key_distribution - self.responder_key_distribution = pairing_config.delegate.local_responder_key_distribution + self.initiator_key_distribution = ( + pairing_config.delegate.local_initiator_key_distribution + ) + self.responder_key_distribution = ( + pairing_config.delegate.local_responder_key_distribution + ) # Authentication Requirements Flags - Vol 3, Part H, Figure 3.3 - self.bonding = pairing_config.bonding - self.sc = pairing_config.sc - self.mitm = pairing_config.mitm + self.bonding = pairing_config.bonding + self.sc = pairing_config.sc + self.mitm = pairing_config.mitm self.keypress = False - self.ct2 = False + self.ct2 = False # I/O Capabilities - self.io_capability = pairing_config.delegate.io_capability + self.io_capability = pairing_config.delegate.io_capability self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY # OOB (not supported yet) @@ -643,22 +681,19 @@ class Session: self_address = connection.self_address peer_address = connection.peer_resolvable_address or connection.peer_address if self.is_initiator: - self.ia = bytes(self_address) + self.ia = bytes(self_address) self.iat = 1 if self_address.is_random else 0 - self.ra = bytes(peer_address) + self.ra = bytes(peer_address) self.rat = 1 if peer_address.is_random else 0 else: - self.ra = bytes(self_address) + self.ra = bytes(self_address) self.rat = 1 if self_address.is_random else 0 - self.ia = bytes(peer_address) + self.ia = bytes(peer_address) self.iat = 1 if peer_address.is_random else 0 @property def pkx(self): - return ( - bytes(reversed(self.manager.ecc_key.x)), - self.peer_public_key_x - ) + return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x) @property def pka(self): @@ -670,10 +705,7 @@ class Session: @property def nx(self): - return ( - self.r, - self.peer_random_value - ) + return (self.r, self.peer_random_value) @property def na(self): @@ -694,7 +726,9 @@ class Session: else: return self.ltk - def decide_pairing_method(self, auth_req, initiator_io_capability, responder_io_capability): + def decide_pairing_method( + self, auth_req, initiator_io_capability, responder_io_capability + ): if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): self.pairing_method = self.JUST_WORKS return @@ -708,7 +742,7 @@ class Session: self.pairing_method = details else: # PASSKEY method, with a method ID and display/input flags - self.pairing_method = details[0] + self.pairing_method = details[0] self.passkey_display = details[1 if self.is_initiator else 2] def check_expected_value(self, expected, received, error): @@ -738,7 +772,9 @@ class Session: async def prompt(): logger.debug(f'verification code: {code}') try: - response = await self.pairing_config.delegate.compare_numbers(code, digits=6) + response = await self.pairing_config.delegate.compare_numbers( + code, digits=6 + ) if response: next_steps() return @@ -772,7 +808,9 @@ class Session: self.tk = self.passkey.to_bytes(16, byteorder='little') logger.debug(f'TK from passkey = {self.tk.hex()}') - asyncio.create_task(self.pairing_config.delegate.display_number(self.passkey, digits=6)) + asyncio.create_task( + self.pairing_config.delegate.display_number(self.passkey, digits=6) + ) def input_passkey(self, next_steps=None): # Prompt the user for the passkey displayed on the peer @@ -785,6 +823,7 @@ class Session: if next_steps is not None: next_steps() + self.prompt_user_for_number(after_input) def display_or_input_passkey(self, next_steps=None): @@ -799,31 +838,31 @@ class Session: self.manager.send_command(self.connection, command) def send_pairing_failed(self, error): - self.send_command(SMP_Pairing_Failed_Command(reason = error)) + self.send_command(SMP_Pairing_Failed_Command(reason=error)) self.on_pairing_failure(error) def send_pairing_request_command(self): self.manager.on_session_start(self) command = SMP_Pairing_Request_Command( - io_capability = self.io_capability, - oob_data_flag = 0, - auth_req = self.auth_req, - maximum_encryption_key_size = 16, - initiator_key_distribution = self.initiator_key_distribution, - responder_key_distribution = self.responder_key_distribution + io_capability=self.io_capability, + oob_data_flag=0, + auth_req=self.auth_req, + maximum_encryption_key_size=16, + initiator_key_distribution=self.initiator_key_distribution, + responder_key_distribution=self.responder_key_distribution, ) self.preq = bytes(command) self.send_command(command) def send_pairing_response_command(self): response = SMP_Pairing_Response_Command( - io_capability = self.io_capability, - oob_data_flag = 0, - auth_req = self.auth_req, - maximum_encryption_key_size = 16, - initiator_key_distribution = self.initiator_key_distribution, - responder_key_distribution = self.responder_key_distribution + io_capability=self.io_capability, + oob_data_flag=0, + auth_req=self.auth_req, + maximum_encryption_key_size=16, + initiator_key_distribution=self.initiator_key_distribution, + responder_key_distribution=self.responder_key_distribution, ) self.pres = bytes(response) self.send_command(response) @@ -833,7 +872,10 @@ class Session: logger.debug(f'generated random: {self.r.hex()}') if self.sc: - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): z = 0 elif self.pairing_method == self.PASSKEY: z = 0x80 + ((self.passkey >> self.passkey_step) & 1) @@ -841,19 +883,9 @@ class Session: return if self.is_initiator: - confirm_value = crypto.f4( - self.pka, - self.pkb, - self.r, - bytes([z]) - ) + confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z])) else: - confirm_value = crypto.f4( - self.pkb, - self.pka, - self.r, - bytes([z]) - ) + confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z])) else: confirm_value = crypto.c1( self.tk, @@ -863,26 +895,26 @@ class Session: self.iat, self.rat, self.ia, - self.ra + self.ra, ) - self.send_command(SMP_Pairing_Confirm_Command(confirm_value = confirm_value)) + self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) def send_pairing_random_command(self): - self.send_command(SMP_Pairing_Random_Command(random_value = self.r)) + self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) def send_public_key_command(self): self.send_command( SMP_Pairing_Public_Key_Command( - public_key_x = bytes(reversed(self.manager.ecc_key.x)), - public_key_y = bytes(reversed(self.manager.ecc_key.y)) + public_key_x=bytes(reversed(self.manager.ecc_key.x)), + public_key_y=bytes(reversed(self.manager.ecc_key.y)), ) ) def send_pairing_dhkey_check_command(self): self.send_command( SMP_Pairing_DHKey_Check_Command( - dhkey_check = self.ea if self.is_initiator else self.eb + dhkey_check=self.ea if self.is_initiator else self.eb ) ) @@ -892,10 +924,10 @@ class Session: asyncio.create_task( self.manager.device.host.send_command( HCI_LE_Enable_Encryption_Command( - connection_handle = self.connection.handle, - random_number = bytes(8), - encrypted_diversifier = 0, - long_term_key = key + connection_handle=self.connection.handle, + random_number=bytes(8), + encrypted_diversifier=0, + long_term_key=key, ) ) ) @@ -903,32 +935,47 @@ class Session: async def derive_ltk(self): link_key = await self.manager.device.get_link_key(self.connection.peer_address) assert link_key is not None - ilk = crypto.h7( - salt=SMP_CTKD_H7_BRLE_SALT, - w=link_key) if self.ct2 else crypto.h6(link_key, b'tmp2') + ilk = ( + crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key) + if self.ct2 + else crypto.h6(link_key, b'tmp2') + ) self.ltk = crypto.h6(ilk, b'brle') def distribute_keys(self): # Distribute the keys as required if self.is_initiator: # CTKD: Derive LTK from LinkKey - if self.connection.transport == BT_BR_EDR_TRANSPORT and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: + if ( + self.connection.transport == BT_BR_EDR_TRANSPORT + and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG + ): self.ctkd_task = asyncio.create_task(self.derive_ltk()) elif not self.sc: # Distribute the LTK, EDIV and RAND if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: - self.send_command(SMP_Encryption_Information_Command(long_term_key=self.ltk)) - self.send_command(SMP_Master_Identification_Command(ediv=self.ltk_ediv, rand=self.ltk_rand)) + self.send_command( + SMP_Encryption_Information_Command(long_term_key=self.ltk) + ) + self.send_command( + SMP_Master_Identification_Command( + ediv=self.ltk_ediv, rand=self.ltk_rand + ) + ) # Distribute IRK & BD ADDR if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG: self.send_command( - SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk) + SMP_Identity_Information_Command( + identity_resolving_key=self.manager.device.irk + ) + ) + self.send_command( + SMP_Identity_Address_Information_Command( + addr_type=self.connection.self_address.address_type, + bd_addr=self.connection.self_address, + ) ) - self.send_command(SMP_Identity_Address_Information_Command( - addr_type = self.connection.self_address.address_type, - bd_addr = self.connection.self_address - )) # Distribute CSRK csrk = bytes(16) # FIXME: testing @@ -937,30 +984,45 @@ class Session: # CTKD, calculate BR/EDR link key if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG: - ilk = crypto.h7( - salt=SMP_CTKD_H7_LEBR_SALT, - w=self.ltk) if self.ct2 else crypto.h6(self.ltk, b'tmp1') + ilk = ( + crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk) + if self.ct2 + else crypto.h6(self.ltk, b'tmp1') + ) self.link_key = crypto.h6(ilk, b'lebr') else: # CTKD: Derive LTK from LinkKey - if self.connection.transport == BT_BR_EDR_TRANSPORT and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: + if ( + self.connection.transport == BT_BR_EDR_TRANSPORT + and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG + ): self.ctkd_task = asyncio.create_task(self.derive_ltk()) # Distribute the LTK, EDIV and RAND elif not self.sc: if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: - self.send_command(SMP_Encryption_Information_Command(long_term_key=self.ltk)) - self.send_command(SMP_Master_Identification_Command(ediv=self.ltk_ediv, rand=self.ltk_rand)) + self.send_command( + SMP_Encryption_Information_Command(long_term_key=self.ltk) + ) + self.send_command( + SMP_Master_Identification_Command( + ediv=self.ltk_ediv, rand=self.ltk_rand + ) + ) # Distribute IRK & BD ADDR if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG: self.send_command( - SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk) + SMP_Identity_Information_Command( + identity_resolving_key=self.manager.device.irk + ) + ) + self.send_command( + SMP_Identity_Address_Information_Command( + addr_type=self.connection.self_address.address_type, + bd_addr=self.connection.self_address, + ) ) - self.send_command(SMP_Identity_Address_Information_Command( - addr_type = self.connection.self_address.address_type, - bd_addr = self.connection.self_address - )) # Distribute CSRK csrk = bytes(16) # FIXME: testing @@ -969,40 +1031,59 @@ class Session: # CTKD, calculate BR/EDR link key if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG: - ilk = crypto.h7( - salt=SMP_CTKD_H7_LEBR_SALT, - w=self.ltk) if self.ct2 else crypto.h6(self.ltk, b'tmp1') + ilk = ( + crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk) + if self.ct2 + else crypto.h6(self.ltk, b'tmp1') + ) self.link_key = crypto.h6(ilk, b'lebr') def compute_peer_expected_distributions(self, key_distribution_flags): # Set our expectations for what to wait for in the key distribution phase self.peer_expected_distributions = [] if not self.sc and self.connection.transport == BT_LE_TRANSPORT: - if (key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0): - self.peer_expected_distributions.append(SMP_Encryption_Information_Command) - self.peer_expected_distributions.append(SMP_Master_Identification_Command) - if (key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0): + if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0: + self.peer_expected_distributions.append( + SMP_Encryption_Information_Command + ) + self.peer_expected_distributions.append( + SMP_Master_Identification_Command + ) + if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0: self.peer_expected_distributions.append(SMP_Identity_Information_Command) - self.peer_expected_distributions.append(SMP_Identity_Address_Information_Command) - if (key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0): + self.peer_expected_distributions.append( + SMP_Identity_Address_Information_Command + ) + if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0: self.peer_expected_distributions.append(SMP_Signing_Information_Command) - logger.debug(f'expecting distributions: {[c.__name__ for c in self.peer_expected_distributions]}') + logger.debug( + f'expecting distributions: {[c.__name__ for c in self.peer_expected_distributions]}' + ) def check_key_distribution(self, command_class): # First, check that the connection is encrypted if not self.connection.is_encrypted: - logger.warn(color('received key distribution on a non-encrypted connection', 'red')) + logger.warn( + color('received key distribution on a non-encrypted connection', 'red') + ) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) return # Check that this command class is expected if command_class in self.peer_expected_distributions: self.peer_expected_distributions.remove(command_class) - logger.debug(f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}') + logger.debug( + f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}' + ) if not self.peer_expected_distributions: self.on_peer_key_distribution_complete() else: - logger.warn(color(f'!!! unexpected key distribution command: {command_class.__name__}', 'red')) + logger.warn( + color( + f'!!! unexpected key distribution command: {command_class.__name__}', + 'red', + ) + ) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) async def pair(self): @@ -1017,8 +1098,13 @@ class Session: def on_disconnection(self, reason): self.connection.remove_listener('disconnection', self.on_disconnection) - self.connection.remove_listener('connection_encryption_change', self.on_connection_encryption_change) - self.connection.remove_listener('connection_encryption_key_refresh', self.on_connection_encryption_key_refresh) + self.connection.remove_listener( + 'connection_encryption_change', self.on_connection_encryption_change + ) + self.connection.remove_listener( + 'connection_encryption_key_refresh', + self.on_connection_encryption_key_refresh, + ) self.manager.on_session_end(self) def on_peer_key_distribution_complete(self): @@ -1069,43 +1155,37 @@ class Session: keys.address_type = peer_address.address_type authenticated = self.pairing_method != self.JUST_WORKS if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT: - keys.ltk = PairingKeys.Key( - value = self.ltk, - authenticated = authenticated - ) + keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated) else: our_ltk_key = PairingKeys.Key( - value = self.ltk, - authenticated = authenticated, - ediv = self.ltk_ediv, - rand = self.ltk_rand + value=self.ltk, + authenticated=authenticated, + ediv=self.ltk_ediv, + rand=self.ltk_rand, ) peer_ltk_key = PairingKeys.Key( - value = self.peer_ltk, - authenticated = authenticated, - ediv = self.peer_ediv, - rand = self.peer_rand + value=self.peer_ltk, + authenticated=authenticated, + ediv=self.peer_ediv, + rand=self.peer_rand, ) if self.is_initiator: - keys.ltk_central = peer_ltk_key + keys.ltk_central = peer_ltk_key keys.ltk_peripheral = our_ltk_key else: - keys.ltk_central = our_ltk_key + keys.ltk_central = our_ltk_key keys.ltk_peripheral = peer_ltk_key if self.peer_identity_resolving_key is not None: keys.irk = PairingKeys.Key( - value = self.peer_identity_resolving_key, - authenticated = authenticated + value=self.peer_identity_resolving_key, authenticated=authenticated ) if self.peer_signature_key is not None: keys.csrk = PairingKeys.Key( - value = self.peer_signature_key, - authenticated = authenticated + value=self.peer_signature_key, authenticated=authenticated ) if self.link_key is not None: keys.link_key = PairingKeys.Key( - value = self.link_key, - authenticated = authenticated + value=self.link_key, authenticated=authenticated ) self.manager.on_pairing(self, peer_address, keys) @@ -1131,7 +1211,9 @@ class Session: handler(command) except Exception as error: logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') - response = SMP_Pairing_Failed_Command(reason = SMP_UNSPECIFIED_REASON_ERROR) + response = SMP_Pairing_Failed_Command( + reason=SMP_UNSPECIFIED_REASON_ERROR + ) self.send_command(response) else: logger.error(color('SMP command not handled???', 'red')) @@ -1152,8 +1234,8 @@ class Session: # Bonding and SC require both sides to request/support it self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0) - self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) - self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0) + self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) + self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0) # Check for OOB if command.oob_data_flag != 0: @@ -1162,15 +1244,19 @@ class Session: # Decide which pairing method to use self.decide_pairing_method( - command.auth_req, - command.io_capability, - self.io_capability + command.auth_req, command.io_capability, self.io_capability + ) + logger.debug( + f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}' ) - logger.debug(f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}') # Key distribution - self.initiator_key_distribution, self.responder_key_distribution = await self.pairing_config.delegate.key_distribution_response( - command.initiator_key_distribution, command.responder_key_distribution) + ( + self.initiator_key_distribution, + self.responder_key_distribution, + ) = await self.pairing_config.delegate.key_distribution_response( + command.initiator_key_distribution, command.responder_key_distribution + ) self.compute_peer_expected_distributions(self.initiator_key_distribution) # The pairing is now starting @@ -1187,7 +1273,12 @@ class Session: # Vol 3, Part C, 5.2.2.1.3 # CTKD over BR/EDR should happen after the connection has been encrypted, # so when receiving pairing requests, responder should start distributing keys - if self.connection.transport == BT_BR_EDR_TRANSPORT and self.connection.is_encrypted and self.is_responder and accepted: + if ( + self.connection.transport == BT_BR_EDR_TRANSPORT + and self.connection.is_encrypted + and self.is_responder + and accepted + ): self.distribute_keys() def on_smp_pairing_response_command(self, command): @@ -1201,7 +1292,7 @@ class Session: # Bonding and SC require both sides to request/support it self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0) - self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) + self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) # Check for OOB if self.sc and command.oob_data_flag: @@ -1210,15 +1301,18 @@ class Session: # Decide which pairing method to use self.decide_pairing_method( - command.auth_req, - self.io_capability, - command.io_capability + command.auth_req, self.io_capability, command.io_capability + ) + logger.debug( + f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}' ) - logger.debug(f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}') # Key distribution - if (command.initiator_key_distribution & ~self.initiator_key_distribution != 0) or \ - (command.responder_key_distribution & ~self.responder_key_distribution != 0): + if ( + command.initiator_key_distribution & ~self.initiator_key_distribution != 0 + ) or ( + command.responder_key_distribution & ~self.responder_key_distribution != 0 + ): # The response isn't a subset of the request self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR) return @@ -1249,7 +1343,10 @@ class Session: self.send_pairing_confirm_command() def on_smp_pairing_confirm_command_secure_connections(self, command): - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): if self.is_initiator: self.r = crypto.r() self.send_pairing_random_command() @@ -1276,12 +1373,10 @@ class Session: self.iat, self.rat, self.ia, - self.ra + self.ra, ) if not self.check_expected_value( - self.confirm_value, - confirm_verifier, - SMP_CONFIRM_VALUE_FAILED_ERROR + self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): return @@ -1305,18 +1400,16 @@ class Session: def on_smp_pairing_random_command_secure_connections(self, command): if self.is_initiator: - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( - self.pkb, - self.pka, - command.random_value, - bytes([0]) + self.pkb, self.pka, command.random_value, bytes([0]) ) if not self.check_expected_value( - self.confirm_value, - confirm_verifier, - SMP_CONFIRM_VALUE_FAILED_ERROR + self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): return elif self.pairing_method == self.PASSKEY: @@ -1325,12 +1418,10 @@ class Session: self.pkb, self.pka, command.random_value, - bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]) + bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]), ) if not self.check_expected_value( - self.confirm_value, - confirm_verifier, - SMP_CONFIRM_VALUE_FAILED_ERROR + self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): return @@ -1343,7 +1434,10 @@ class Session: else: return else: - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): self.send_pairing_random_command() elif self.pairing_method == self.PASSKEY: # Check that the random value matches what was committed to earlier @@ -1351,12 +1445,10 @@ class Session: self.pka, self.pkb, command.random_value, - bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]) + bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]), ) if not self.check_expected_value( - self.confirm_value, - confirm_verifier, - SMP_CONFIRM_VALUE_FAILED_ERROR + self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): return @@ -1377,7 +1469,10 @@ class Session: (mac_key, self.ltk) = crypto.f5(self.dh_key, self.na, self.nb, a, b) # Compute the DH Key checks - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): ra = bytes(16) rb = ra elif self.pairing_method == self.PASSKEY: @@ -1402,7 +1497,10 @@ class Session: self.wait_before_continuing.set_result(None) # Prompt the user for confirmation if needed - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): # Compute the 6-digit code code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000 @@ -1428,10 +1526,14 @@ class Session: self.peer_public_key_y = command.public_key_y # Compute the DH key - self.dh_key = bytes(reversed(self.manager.ecc_key.dh( - bytes(reversed(command.public_key_x)), - bytes(reversed(command.public_key_y)) - ))) + self.dh_key = bytes( + reversed( + self.manager.ecc_key.dh( + bytes(reversed(command.public_key_x)), + bytes(reversed(command.public_key_y)), + ) + ) + ) logger.debug(f'DH key: {self.dh_key.hex()}') if self.is_initiator: @@ -1447,7 +1549,10 @@ class Session: else: self.send_public_key_command() - if self.pairing_method == self.JUST_WORKS or self.pairing_method == self.NUMERIC_COMPARISON: + if ( + self.pairing_method == self.JUST_WORKS + or self.pairing_method == self.NUMERIC_COMPARISON + ): # We can now send the confirmation value self.send_pairing_confirm_command() @@ -1455,14 +1560,13 @@ class Session: # Check that what we received matches what we computed earlier expected = self.eb if self.is_initiator else self.ea if not self.check_expected_value( - expected, - command.dhkey_check, - SMP_DHKEY_CHECK_FAILED_ERROR + expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR ): return if self.is_responder: if self.wait_before_continuing is not None: + async def next_steps(): await self.wait_before_continuing self.wait_before_continuing = None @@ -1507,13 +1611,15 @@ class Manager(EventEmitter): def __init__(self, device): super().__init__() - self.device = device - self.sessions = {} - self._ecc_key = None + self.device = device + self.sessions = {} + self._ecc_key = None self.pairing_config_factory = lambda connection: PairingConfig() def send_command(self, connection, command): - logger.debug(f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}') + logger.debug( + f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}' + ) cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID connection.send_l2cap_pdu(cid, command.to_bytes()) @@ -1525,9 +1631,7 @@ class Manager(EventEmitter): # Pairing disabled self.send_command( connection, - SMP_Pairing_Failed_Command( - reason = SMP_PAIRING_NOT_SUPPORTED_ERROR - ) + SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR), ) return session = Session(self, connection, pairing_config) @@ -1535,7 +1639,9 @@ class Manager(EventEmitter): # Parse the L2CAP payload into an SMP Command object command = SMP_Command.from_bytes(pdu) - logger.debug(f'<<< Received SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}') + logger.debug( + f'<<< Received SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}' + ) # Delegate the handling of the command to the session session.on_smp_command(command) @@ -1563,7 +1669,7 @@ class Manager(EventEmitter): pairing_config.mitm, pairing_config.sc, False, - False + False, ) else: auth_req = 0 @@ -1575,11 +1681,13 @@ class Manager(EventEmitter): def on_pairing(self, session, identity_address, keys): # Store the keys in the key store if self.device.keystore and identity_address is not None: + async def store_keys(): try: await self.device.keystore.update(str(identity_address), keys) except Exception as error: logger.warn(f'!!! error while storing keys: {error}') + asyncio.create_task(store_keys()) # Notify the device diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index c3bd5f8b..c5007d2a 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -38,42 +38,55 @@ async def open_transport(name): scheme, *spec = name.split(':', 1) if scheme == 'serial' and spec: from .serial import open_serial_transport + return await open_serial_transport(spec[0]) elif scheme == 'udp' and spec: from .udp import open_udp_transport + return await open_udp_transport(spec[0]) elif scheme == 'tcp-client' and spec: from .tcp_client import open_tcp_client_transport + return await open_tcp_client_transport(spec[0]) elif scheme == 'tcp-server' and spec: from .tcp_server import open_tcp_server_transport + return await open_tcp_server_transport(spec[0]) elif scheme == 'ws-client' and spec: from .ws_client import open_ws_client_transport + return await open_ws_client_transport(spec[0]) elif scheme == 'ws-server' and spec: from .ws_server import open_ws_server_transport + return await open_ws_server_transport(spec[0]) elif scheme == 'pty': from .pty import open_pty_transport + return await open_pty_transport(spec[0] if spec else None) elif scheme == 'file': from .file import open_file_transport + return await open_file_transport(spec[0] if spec else None) elif scheme == 'vhci': from .vhci import open_vhci_transport + return await open_vhci_transport(spec[0] if spec else None) elif scheme == 'hci-socket': from .hci_socket import open_hci_socket_transport + return await open_hci_socket_transport(spec[0] if spec else None) elif scheme == 'usb': from .usb import open_usb_transport + return await open_usb_transport(spec[0] if spec else None) elif scheme == 'pyusb': from .pyusb import open_pyusb_transport + return await open_pyusb_transport(spec[0] if spec else None) elif scheme == 'android-emulator': from .android_emulator import open_android_emulator_transport + return await open_android_emulator_transport(spec[0] if spec else None) else: raise ValueError('unknown transport scheme') @@ -84,7 +97,7 @@ async def open_transport_or_link(name): if name.startswith('link-relay:'): link = RemoteLink(name[11:]) await link.wait_until_connected() - controller = Controller('remote', link = link) + controller = Controller('remote', link=link) class LinkTransport(Transport): async def close(self): diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index d27aef69..f9aabccf 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -59,15 +59,10 @@ async def open_android_emulator_transport(spec): return bytes([packet.type]) + packet.packet async def write(self, packet): - await self.hci_device.write( - HCIPacket( - type = packet[0], - packet = packet[1:] - ) - ) + await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:])) # Parse the parameters - mode = 'host' + mode = 'host' server_host = 'localhost' server_port = 8554 if spec is not None: @@ -100,7 +95,7 @@ async def open_android_emulator_transport(spec): transport = PumpedTransport( PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write), - channel.close + channel.close, ) transport.start() diff --git a/bumble/transport/common.py b/bumble/transport/common.py index 0f5d27f4..a8ea6d06 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -33,10 +33,10 @@ logger = logging.getLogger(__name__) # For each packet type, the info represents: # (length-size, length-offset, unpack-type) HCI_PACKET_INFO = { - hci.HCI_COMMAND_PACKET: (1, 2, 'B'), - hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), + hci.HCI_COMMAND_PACKET: (1, 2, 'B'), + hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), - hci.HCI_EVENT_PACKET: (1, 1, 'B') + hci.HCI_EVENT_PACKET: (1, 1, 'B'), } @@ -48,7 +48,7 @@ class PacketPump: def __init__(self, reader, sink): self.reader = reader - self.sink = sink + self.sink = sink async def run(self): while True: @@ -67,41 +67,46 @@ class PacketParser: ''' In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed ''' - NEED_TYPE = 0 - NEED_LENGTH = 1 - NEED_BODY = 2 - def __init__(self, sink = None): + NEED_TYPE = 0 + NEED_LENGTH = 1 + NEED_BODY = 2 + + def __init__(self, sink=None): self.sink = sink self.extended_packet_info = {} self.reset() def reset(self): - self.state = PacketParser.NEED_TYPE + self.state = PacketParser.NEED_TYPE self.bytes_needed = 1 - self.packet = bytearray() - self.packet_info = None + self.packet = bytearray() + self.packet_info = None def feed_data(self, data): data_offset = 0 data_left = len(data) while data_left and self.bytes_needed: consumed = min(self.bytes_needed, data_left) - self.packet.extend(data[data_offset:data_offset + consumed]) - data_offset += consumed - data_left -= consumed + self.packet.extend(data[data_offset : data_offset + consumed]) + data_offset += consumed + data_left -= consumed self.bytes_needed -= consumed if self.bytes_needed == 0: if self.state == PacketParser.NEED_TYPE: packet_type = self.packet[0] - self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type) + self.packet_info = HCI_PACKET_INFO.get( + packet_type + ) or self.extended_packet_info.get(packet_type) if self.packet_info is None: raise ValueError(f'invalid packet type {packet_type}') self.state = PacketParser.NEED_LENGTH self.bytes_needed = self.packet_info[0] + self.packet_info[1] elif self.state == PacketParser.NEED_LENGTH: - body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0] + body_length = struct.unpack_from( + self.packet_info[2], self.packet, 1 + self.packet_info[1] + )[0] self.bytes_needed = body_length self.state = PacketParser.NEED_BODY @@ -111,7 +116,9 @@ class PacketParser: try: self.sink.on_packet(bytes(self.packet)) except Exception as error: - logger.warning(color(f'!!! Exception in on_packet: {error}', 'red')) + logger.warning( + color(f'!!! Exception in on_packet: {error}', 'red') + ) self.reset() def set_packet_sink(self, sink): @@ -187,6 +194,7 @@ class AsyncPipeSink: ''' Sink that forwards packets asynchronously to another sink ''' + def __init__(self, sink): self.sink = sink self.loop = asyncio.get_running_loop() @@ -202,7 +210,7 @@ class ParserSource: """ def __init__(self): - self.parser = PacketParser() + self.parser = PacketParser() self.terminated = asyncio.get_running_loop().create_future() def set_packet_sink(self, sink): @@ -237,7 +245,7 @@ class StreamPacketSink: class Transport: def __init__(self, source, sink): self.source = source - self.sink = sink + self.sink = sink async def __aenter__(self): return self @@ -258,7 +266,7 @@ class PumpedPacketSource(ParserSource): def __init__(self, receive): super().__init__() self.receive_function = receive - self.pump_task = None + self.pump_task = None def start(self): async def pump_packets(): @@ -285,8 +293,8 @@ class PumpedPacketSource(ParserSource): class PumpedPacketSink: def __init__(self, send): self.send_function = send - self.packet_queue = asyncio.Queue() - self.pump_task = None + self.packet_queue = asyncio.Queue() + self.pump_task = None def on_packet(self, packet): self.packet_queue.put_nowait(packet) diff --git a/bumble/transport/emulated_bluetooth_packets_pb2.py b/bumble/transport/emulated_bluetooth_packets_pb2.py index 9d3591de..9fafd611 100644 --- a/bumble/transport/emulated_bluetooth_packets_pb2.py +++ b/bumble/transport/emulated_bluetooth_packets_pb2.py @@ -21,32 +21,36 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') - +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3' +) _HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket'] _HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType'] -HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), { - 'DESCRIPTOR' : _HCIPACKET, - '__module__' : 'emulated_bluetooth_packets_pb2' - # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket) - }) +HCIPacket = _reflection.GeneratedProtocolMessageType( + 'HCIPacket', + (_message.Message,), + { + 'DESCRIPTOR': _HCIPACKET, + '__module__': 'emulated_bluetooth_packets_pb2' + # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket) + }, +) _sym_db.RegisterMessage(HCIPacket) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' - _HCIPACKET._serialized_start=66 - _HCIPACKET._serialized_end=317 - _HCIPACKET_PACKETTYPE._serialized_start=161 - _HCIPACKET_PACKETTYPE._serialized_end=317 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' + _HCIPACKET._serialized_start = 66 + _HCIPACKET._serialized_end = 317 + _HCIPACKET_PACKETTYPE._serialized_start = 161 + _HCIPACKET_PACKETTYPE._serialized_end = 317 # @@protoc_insertion_point(module_scope) diff --git a/bumble/transport/emulated_bluetooth_pb2.py b/bumble/transport/emulated_bluetooth_pb2.py index 4da12d53..2689a990 100644 --- a/bumble/transport/emulated_bluetooth_pb2.py +++ b/bumble/transport/emulated_bluetooth_pb2.py @@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -29,25 +30,30 @@ _sym_db = _symbol_database.Default() from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3') - +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3' +) _RAWDATA = DESCRIPTOR.message_types_by_name['RawData'] -RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), { - 'DESCRIPTOR' : _RAWDATA, - '__module__' : 'emulated_bluetooth_pb2' - # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData) - }) +RawData = _reflection.GeneratedProtocolMessageType( + 'RawData', + (_message.Message,), + { + 'DESCRIPTOR': _RAWDATA, + '__module__': 'emulated_bluetooth_pb2' + # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData) + }, +) _sym_db.RegisterMessage(RawData) _EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService'] if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001' - _RAWDATA._serialized_start=91 - _RAWDATA._serialized_end=116 - _EMULATEDBLUETOOTHSERVICE._serialized_start=119 - _EMULATEDBLUETOOTHSERVICE._serialized_end=450 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001' + _RAWDATA._serialized_start = 91 + _RAWDATA._serialized_end = 116 + _EMULATEDBLUETOOTHSERVICE._serialized_start = 119 + _EMULATEDBLUETOOTHSERVICE._serialized_end = 450 # @@protoc_insertion_point(module_scope) diff --git a/bumble/transport/emulated_bluetooth_pb2_grpc.py b/bumble/transport/emulated_bluetooth_pb2_grpc.py index cc0ce37f..c7ea6d0e 100644 --- a/bumble/transport/emulated_bluetooth_pb2_grpc.py +++ b/bumble/transport/emulated_bluetooth_pb2_grpc.py @@ -39,20 +39,20 @@ class EmulatedBluetoothServiceStub(object): channel: A grpc.Channel. """ self.registerClassicPhy = channel.stream_stream( - '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', - request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, - response_deserializer=emulated__bluetooth__pb2.RawData.FromString, - ) + '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', + request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, + response_deserializer=emulated__bluetooth__pb2.RawData.FromString, + ) self.registerBlePhy = channel.stream_stream( - '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', - request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, - response_deserializer=emulated__bluetooth__pb2.RawData.FromString, - ) + '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', + request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, + response_deserializer=emulated__bluetooth__pb2.RawData.FromString, + ) self.registerHCIDevice = channel.stream_stream( - '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', - request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, - response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, - ) + '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', + request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, + response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, + ) class EmulatedBluetoothServiceServicer(object): @@ -121,28 +121,29 @@ class EmulatedBluetoothServiceServicer(object): def add_EmulatedBluetoothServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'registerClassicPhy': grpc.stream_stream_rpc_method_handler( - servicer.registerClassicPhy, - request_deserializer=emulated__bluetooth__pb2.RawData.FromString, - response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, - ), - 'registerBlePhy': grpc.stream_stream_rpc_method_handler( - servicer.registerBlePhy, - request_deserializer=emulated__bluetooth__pb2.RawData.FromString, - response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, - ), - 'registerHCIDevice': grpc.stream_stream_rpc_method_handler( - servicer.registerHCIDevice, - request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, - response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, - ), + 'registerClassicPhy': grpc.stream_stream_rpc_method_handler( + servicer.registerClassicPhy, + request_deserializer=emulated__bluetooth__pb2.RawData.FromString, + response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, + ), + 'registerBlePhy': grpc.stream_stream_rpc_method_handler( + servicer.registerBlePhy, + request_deserializer=emulated__bluetooth__pb2.RawData.FromString, + response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, + ), + 'registerHCIDevice': grpc.stream_stream_rpc_method_handler( + servicer.registerHCIDevice, + request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, + response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers) + 'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class EmulatedBluetoothService(object): """An Emulated Bluetooth Service exposes the emulated bluetooth chip from the android emulator. It allows you to register emulated bluetooth devices and @@ -156,52 +157,88 @@ class EmulatedBluetoothService(object): """ @staticmethod - def registerClassicPhy(request_iterator, + def registerClassicPhy( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', + '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def registerBlePhy(request_iterator, + def registerBlePhy( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', + '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def registerHCIDevice(request_iterator, + def registerHCIDevice( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', + '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/bumble/transport/emulated_bluetooth_vhci_pb2.py b/bumble/transport/emulated_bluetooth_vhci_pb2.py index a6384399..d046577e 100644 --- a/bumble/transport/emulated_bluetooth_vhci_pb2.py +++ b/bumble/transport/emulated_bluetooth_vhci_pb2.py @@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -29,15 +30,16 @@ _sym_db = _symbol_database.Default() import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') - +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3' +) _VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService'] if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' - _VHCIFORWARDINGSERVICE._serialized_start=96 - _VHCIFORWARDINGSERVICE._serialized_end=217 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' + _VHCIFORWARDINGSERVICE._serialized_start = 96 + _VHCIFORWARDINGSERVICE._serialized_end = 217 # @@protoc_insertion_point(module_scope) diff --git a/bumble/transport/emulated_bluetooth_vhci_pb2_grpc.py b/bumble/transport/emulated_bluetooth_vhci_pb2_grpc.py index 94140d7e..94952178 100644 --- a/bumble/transport/emulated_bluetooth_vhci_pb2_grpc.py +++ b/bumble/transport/emulated_bluetooth_vhci_pb2_grpc.py @@ -35,10 +35,10 @@ class VhciForwardingServiceStub(object): channel: A grpc.Channel. """ self.attachVhci = channel.stream_stream( - '/android.emulation.bluetooth.VhciForwardingService/attachVhci', - request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, - response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, - ) + '/android.emulation.bluetooth.VhciForwardingService/attachVhci', + request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, + response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, + ) class VhciForwardingServiceServicer(object): @@ -75,18 +75,19 @@ class VhciForwardingServiceServicer(object): def add_VhciForwardingServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'attachVhci': grpc.stream_stream_rpc_method_handler( - servicer.attachVhci, - request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, - response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, - ), + 'attachVhci': grpc.stream_stream_rpc_method_handler( + servicer.attachVhci, + request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, + response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers) + 'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class VhciForwardingService(object): """This is a service which allows you to directly intercept the VHCI packets that are coming and going to the device before they are delivered to @@ -97,18 +98,30 @@ class VhciForwardingService(object): """ @staticmethod - def attachVhci(request_iterator, + def attachVhci( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci', + '/android.emulation.bluetooth.VhciForwardingService/attachVhci', emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/bumble/transport/file.py b/bumble/transport/file.py index 841c62a5..c0c73e85 100644 --- a/bumble/transport/file.py +++ b/bumble/transport/file.py @@ -39,14 +39,12 @@ async def open_file_transport(spec): # Setup reading read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe( - lambda: StreamPacketSource(), - file + lambda: StreamPacketSource(), file ) # Setup writing write_transport, _ = await asyncio.get_running_loop().connect_write_pipe( - lambda: asyncio.BaseProtocol(), - file + lambda: asyncio.BaseProtocol(), file ) packet_sink = StreamPacketSink(write_transport) @@ -57,4 +55,3 @@ async def open_file_transport(spec): file.close() return FileTransport(packet_source, packet_sink) - diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index f74a5357..31456c18 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -44,7 +44,11 @@ async def open_hci_socket_transport(spec): # Create a raw HCI socket try: - hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI) + hci_socket = socket.socket( + socket.AF_BLUETOOTH, + socket.SOCK_RAW | socket.SOCK_NONBLOCK, + socket.BTPROTO_HCI, + ) except AttributeError: # Not supported on this platform logger.info("HCI sockets not supported on this platform") @@ -67,15 +71,26 @@ async def open_hci_socket_transport(spec): raise Exception('Bluetooth HCI sockets not supported on this platform') libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int) libc.bind.restype = ctypes.c_int - bind_address = struct.pack(' the BT USB dongle with vendor=04b4 and product=f901 ''' - USB_RECIPIENT_DEVICE = 0x00 - USB_REQUEST_TYPE_CLASS = 0x01 << 5 - USB_ENDPOINT_EVENTS_IN = 0x81 - USB_ENDPOINT_ACL_IN = 0x82 - USB_ENDPOINT_SCO_IN = 0x83 - USB_ENDPOINT_ACL_OUT = 0x02 + USB_RECIPIENT_DEVICE = 0x00 + USB_REQUEST_TYPE_CLASS = 0x01 << 5 + USB_ENDPOINT_EVENTS_IN = 0x81 + USB_ENDPOINT_ACL_IN = 0x82 + USB_ENDPOINT_SCO_IN = 0x83 + USB_ENDPOINT_ACL_OUT = 0x02 # USB_ENDPOINT_SCO_OUT = 0x03 - USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 - USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 + USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 + USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 - READ_SIZE = 1024 + READ_SIZE = 1024 READ_TIMEOUT = 1000 class UsbPacketSink: def __init__(self, device): - self.device = device - self.thread = threading.Thread(target=self.run) - self.loop = asyncio.get_running_loop() + self.device = device + self.thread = threading.Thread(target=self.run) + self.loop = asyncio.get_running_loop() self.stop_event = None def on_packet(self, packet): @@ -80,9 +80,17 @@ async def open_pyusb_transport(spec): if packet_type == hci.HCI_ACL_DATA_PACKET: self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:]) elif packet_type == hci.HCI_COMMAND_PACKET: - self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:]) + self.device.ctrl_transfer( + USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, + 0, + 0, + 0, + packet[1:], + ) else: - logger.warning(color(f'unsupported packet type {packet_type}', 'red')) + logger.warning( + color(f'unsupported packet type {packet_type}', 'red') + ) except usb.core.USBTimeoutError: logger.warning('USB Write Timeout') except usb.core.USBError as error: @@ -105,17 +113,15 @@ async def open_pyusb_transport(spec): class UsbPacketSource(asyncio.Protocol, ParserSource): def __init__(self, device, sco_enabled): super().__init__() - self.device = device - self.loop = asyncio.get_running_loop() - self.queue = asyncio.Queue() + self.device = device + self.loop = asyncio.get_running_loop() + self.queue = asyncio.Queue() self.event_thread = threading.Thread( - target=self.run, - args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET) + target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET) ) self.event_thread.stop_event = None self.acl_thread = threading.Thread( - target=self.run, - args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET) + target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET) ) self.acl_thread.stop_event = None @@ -124,7 +130,7 @@ async def open_pyusb_transport(spec): if sco_enabled: self.sco_thread = threading.Thread( target=self.run, - args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET) + args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET), ) self.sco_thread.stop_event = None @@ -155,7 +161,7 @@ async def open_pyusb_transport(spec): # Create stop events and wait for them to be signaled self.event_thread.stop_event = asyncio.Event() - self.acl_thread.stop_event = asyncio.Event() + self.acl_thread.stop_event = asyncio.Event() await self.event_thread.stop_event.wait() await self.acl_thread.stop_event.wait() if self.sco_enabled: @@ -197,15 +203,19 @@ async def open_pyusb_transport(spec): # Find the device according to the spec moniker if ':' in spec: vendor_id, product_id = spec.split(':') - device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)) + device = usb.core.find( + idVendor=int(vendor_id, 16), idProduct=int(product_id, 16) + ) else: device_index = int(spec) - devices = list(usb.core.find( - find_all = 1, - bDeviceClass = USB_DEVICE_CLASS_WIRELESS_CONTROLLER, - bDeviceSubClass = USB_DEVICE_SUBCLASS_RF_CONTROLLER, - bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER - )) + devices = list( + usb.core.find( + find_all=1, + bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER, + bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER, + bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, + ) + ) if len(devices) > device_index: device = devices[device_index] else: @@ -273,4 +283,4 @@ async def open_pyusb_transport(spec): packet_source.start() packet_sink.start() - return UsbTransport(device, packet_source, packet_sink) \ No newline at end of file + return UsbTransport(device, packet_source, packet_sink) diff --git a/bumble/transport/serial.py b/bumble/transport/serial.py index b760a29a..94008a59 100644 --- a/bumble/transport/serial.py +++ b/bumble/transport/serial.py @@ -64,9 +64,8 @@ async def open_serial_transport(spec): device, baudrate=speed, rtscts=rtscts, - dsrdtr=dsrdtr + dsrdtr=dsrdtr, ) packet_sink = StreamPacketSink(serial_transport) return Transport(packet_source, packet_sink) - diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py index 68066838..d4c004ee 100644 --- a/bumble/transport/tcp_server.py +++ b/bumble/transport/tcp_server.py @@ -45,7 +45,7 @@ async def open_tcp_server_transport(spec): class TcpServerProtocol: def __init__(self, packet_source, packet_sink): self.packet_source = packet_source - self.packet_sink = packet_sink + self.packet_sink = packet_sink # Called when a new connection is established def connection_made(self, transport): @@ -78,7 +78,7 @@ async def open_tcp_server_transport(spec): local_host, local_port = spec.split(':') packet_source = StreamPacketSource() - packet_sink = TcpServerPacketSink() + packet_sink = TcpServerPacketSink() await asyncio.get_running_loop().create_server( lambda: TcpServerProtocol(packet_source, packet_sink), host=local_host if local_host != '_' else None, diff --git a/bumble/transport/udp.py b/bumble/transport/udp.py index f4c59eab..8f9bec99 100644 --- a/bumble/transport/udp.py +++ b/bumble/transport/udp.py @@ -53,10 +53,13 @@ async def open_udp_transport(spec): local, remote = spec.split(',') local_host, local_port = local.split(':') remote_host, remote_port = remote.split(':') - udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint( + ( + udp_transport, + packet_source, + ) = await asyncio.get_running_loop().create_datagram_endpoint( lambda: UdpPacketSource(), local_addr=(local_host, int(local_port)), - remote_addr=(remote_host, int(remote_port)) + remote_addr=(remote_host, int(remote_port)), ) packet_sink = UdpPacketSink(udp_transport) diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index 1133a5ea..a11dfb05 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -59,33 +59,33 @@ async def open_usb_transport(spec): usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode. ''' - USB_RECIPIENT_DEVICE = 0x00 - USB_REQUEST_TYPE_CLASS = 0x01 << 5 - USB_DEVICE_CLASS_DEVICE = 0x00 - USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 - USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 + USB_RECIPIENT_DEVICE = 0x00 + USB_REQUEST_TYPE_CLASS = 0x01 << 5 + USB_DEVICE_CLASS_DEVICE = 0x00 + USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 + USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 - USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02 - USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03 - USB_ENDPOINT_IN = 0x80 + USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02 + USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03 + USB_ENDPOINT_IN = 0x80 USB_BT_HCI_CLASS_TUPLE = ( USB_DEVICE_CLASS_WIRELESS_CONTROLLER, USB_DEVICE_SUBCLASS_RF_CONTROLLER, - USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER + USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, ) READ_SIZE = 1024 class UsbPacketSink: def __init__(self, device, acl_out): - self.device = device - self.acl_out = acl_out - self.transfer = device.getTransfer() - self.packets = collections.deque() # Queue of packets waiting to be sent - self.loop = asyncio.get_running_loop() + self.device = device + self.acl_out = acl_out + self.transfer = device.getTransfer() + self.packets = collections.deque() # Queue of packets waiting to be sent + self.loop = asyncio.get_running_loop() self.cancel_done = self.loop.create_future() - self.closed = False + self.closed = False def start(self): pass @@ -114,7 +114,9 @@ async def open_usb_transport(spec): elif status == usb1.TRANSFER_CANCELLED: self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) else: - logger.warning(color(f'!!! out transfer not completed: status={status}', 'red')) + logger.warning( + color(f'!!! out transfer not completed: status={status}', 'red') + ) def on_packet_sent_(self): if self.packets: @@ -129,17 +131,18 @@ async def open_usb_transport(spec): packet_type = packet[0] if packet_type == hci.HCI_ACL_DATA_PACKET: self.transfer.setBulk( - self.acl_out, - packet[1:], - callback=self.on_packet_sent + self.acl_out, packet[1:], callback=self.on_packet_sent ) logger.debug('submit ACL') self.transfer.submit() elif packet_type == hci.HCI_COMMAND_PACKET: self.transfer.setControl( - USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, + USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, + 0, + 0, + 0, packet[1:], - callback=self.on_packet_sent + callback=self.on_packet_sent, ) logger.debug('submit COMMAND') self.transfer.submit() @@ -167,17 +170,17 @@ async def open_usb_transport(spec): class UsbPacketSource(asyncio.Protocol, ParserSource): def __init__(self, context, device, acl_in, events_in): super().__init__() - self.context = context - self.device = device - self.acl_in = acl_in - self.events_in = events_in - self.loop = asyncio.get_running_loop() - self.queue = asyncio.Queue() - self.closed = False + self.context = context + self.device = device + self.acl_in = acl_in + self.events_in = events_in + self.loop = asyncio.get_running_loop() + self.queue = asyncio.Queue() + self.closed = False self.event_loop_done = self.loop.create_future() self.cancel_done = { - hci.HCI_EVENT_PACKET: self.loop.create_future(), - hci.HCI_ACL_DATA_PACKET: self.loop.create_future() + hci.HCI_EVENT_PACKET: self.loop.create_future(), + hci.HCI_ACL_DATA_PACKET: self.loop.create_future(), } # Create a thread to process events @@ -190,7 +193,7 @@ async def open_usb_transport(spec): self.events_in, READ_SIZE, callback=self.on_packet_received, - user_data=hci.HCI_EVENT_PACKET + user_data=hci.HCI_EVENT_PACKET, ) self.events_in_transfer.submit() @@ -199,7 +202,7 @@ async def open_usb_transport(spec): self.acl_in, READ_SIZE, callback=self.on_packet_received, - user_data=hci.HCI_ACL_DATA_PACKET + user_data=hci.HCI_ACL_DATA_PACKET, ) self.acl_in_transfer.submit() @@ -212,13 +215,20 @@ async def open_usb_transport(spec): # logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}') if status == usb1.TRANSFER_COMPLETED: - packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()] + packet = ( + bytes([packet_type]) + + transfer.getBuffer()[: transfer.getActualLength()] + ) self.loop.call_soon_threadsafe(self.queue.put_nowait, packet) elif status == usb1.TRANSFER_CANCELLED: - self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None) + self.loop.call_soon_threadsafe( + self.cancel_done[packet_type].set_result, None + ) return else: - logger.warning(color(f'!!! transfer not completed: status={status}', 'red')) + logger.warning( + color(f'!!! transfer not completed: status={status}', 'red') + ) # Re-submit the transfer so we can receive more data transfer.submit() @@ -233,7 +243,10 @@ async def open_usb_transport(spec): def run(self): logger.debug('starting USB event loop') - while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted(): + while ( + self.events_in_transfer.isSubmitted() + or self.acl_in_transfer.isSubmitted() + ): try: self.context.handleEvents() except usb1.USBErrorInterrupted: @@ -253,11 +266,15 @@ async def open_usb_transport(spec): packet_type = transfer.getUserData() try: transfer.cancel() - logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...') + logger.debug( + f'waiting for IN[{packet_type}] transfer cancellation to be done...' + ) await self.cancel_done[packet_type] logger.debug(f'IN[{packet_type}] transfer cancellation done') except usb1.USBError: - logger.debug(f'IN[{packet_type}] transfer likely already completed') + logger.debug( + f'IN[{packet_type}] transfer likely already completed' + ) # Wait for the thread to terminate await self.event_loop_done @@ -265,8 +282,8 @@ async def open_usb_transport(spec): class UsbTransport(Transport): def __init__(self, context, device, interface, setting, source, sink): super().__init__(source, sink) - self.context = context - self.device = device + self.context = context + self.device = device self.interface = interface # Get exclusive access @@ -315,9 +332,9 @@ async def open_usb_transport(spec): except usb1.USBError: device_serial_number = None if ( - device.getVendorID() == int(vendor_id, 16) and - device.getProductID() == int(product_id, 16) and - (serial_number is None or serial_number == device_serial_number) + device.getVendorID() == int(vendor_id, 16) + and device.getProductID() == int(product_id, 16) + and (serial_number is None or serial_number == device_serial_number) ): if device_index == 0: found = device @@ -328,8 +345,11 @@ async def open_usb_transport(spec): # Look for a compatible device by index def device_is_bluetooth_hci(device): # Check if the device class indicates a match - if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == \ - USB_BT_HCI_CLASS_TUPLE: + if ( + device.getDeviceClass(), + device.getDeviceSubClass(), + device.getDeviceProtocol(), + ) == USB_BT_HCI_CLASS_TUPLE: return True # If the device class is 'Device', look for a matching interface @@ -337,8 +357,11 @@ async def open_usb_transport(spec): for configuration in device: for interface in configuration: for setting in interface: - if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == \ - USB_BT_HCI_CLASS_TUPLE: + if ( + setting.getClass(), + setting.getSubClass(), + setting.getProtocol(), + ) == USB_BT_HCI_CLASS_TUPLE: return True return False @@ -366,38 +389,52 @@ async def open_usb_transport(spec): setting = None for setting in interface: if ( - not forced_mode and - (setting.getClass(), setting.getSubClass(), setting.getProtocol()) != USB_BT_HCI_CLASS_TUPLE + not forced_mode + and ( + setting.getClass(), + setting.getSubClass(), + setting.getProtocol(), + ) + != USB_BT_HCI_CLASS_TUPLE ): continue events_in = None - acl_in = None - acl_out = None + acl_in = None + acl_out = None for endpoint in setting: attributes = endpoint.getAttributes() - address = endpoint.getAddress() + address = endpoint.getAddress() if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK: if address & USB_ENDPOINT_IN and acl_in is None: acl_in = address elif acl_out is None: acl_out = address - elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT: + elif ( + attributes & 0x03 + == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT + ): if address & USB_ENDPOINT_IN and events_in is None: events_in = address # Return if we found all 3 endpoints - if acl_in is not None and acl_out is not None and events_in is not None: + if ( + acl_in is not None + and acl_out is not None + and events_in is not None + ): return ( configuration_index + 1, setting.getNumber(), setting.getAlternateSetting(), acl_in, acl_out, - events_in + events_in, ) else: - logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}') + logger.debug( + f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}' + ) endpoints = find_endpoints(found) if endpoints is None: @@ -437,7 +474,7 @@ async def open_usb_transport(spec): logger.warning('failed to set configuration') source = UsbPacketSource(context, device, acl_in, events_in) - sink = UsbPacketSink(device, acl_out) + sink = UsbPacketSink(device, acl_out) return UsbTransport(context, device, interface, setting, source, sink) except usb1.USBError as error: logger.warning(color(f'!!! failed to open USB device: {error}', 'red')) diff --git a/bumble/transport/vhci.py b/bumble/transport/vhci.py index 572c31d2..ec61ab43 100644 --- a/bumble/transport/vhci.py +++ b/bumble/transport/vhci.py @@ -33,7 +33,7 @@ async def open_vhci_transport(spec): path at /dev/vhci), or the path of a VHCI device ''' - HCI_VENDOR_PKT = 0xff + HCI_VENDOR_PKT = 0xFF HCI_BREDR = 0x00 # Controller type # Open the VHCI device @@ -56,4 +56,3 @@ async def open_vhci_transport(spec): transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) return transport - diff --git a/bumble/transport/ws_client.py b/bumble/transport/ws_client.py index 9ee7e493..85f6e88c 100644 --- a/bumble/transport/ws_client.py +++ b/bumble/transport/ws_client.py @@ -43,7 +43,7 @@ async def open_ws_client_transport(spec): transport = PumpedTransport( PumpedPacketSource(websocket.recv), PumpedPacketSink(websocket.send), - websocket.close + websocket.close, ) transport.start() return transport diff --git a/bumble/transport/ws_server.py b/bumble/transport/ws_server.py index 3b2d15e6..98d13629 100644 --- a/bumble/transport/ws_server.py +++ b/bumble/transport/ws_server.py @@ -41,8 +41,8 @@ async def open_ws_server_transport(spec): class WsServerTransport(Transport): def __init__(self): - source = ParserSource() - sink = PumpedPacketSink(self.send_packet) + source = ParserSource() + sink = PumpedPacketSink(self.send_packet) self.connection = asyncio.get_running_loop().create_future() super().__init__(source, sink) @@ -50,14 +50,16 @@ async def open_ws_server_transport(spec): async def serve(self, local_host, local_port): self.sink.start() self.server = await websockets.serve( - ws_handler = self.on_connection, - host = local_host if local_host != '_' else None, - port = int(local_port) + ws_handler=self.on_connection, + host=local_host if local_host != '_' else None, + port=int(local_port), ) logger.debug(f'websocket server ready on port {local_port}') async def on_connection(self, connection): - logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}') + logger.debug( + f'new connection on {connection.local_address} from {connection.remote_address}' + ) self.connection.set_result(connection) try: async for packet in connection: diff --git a/bumble/utils.py b/bumble/utils.py index 5d8ab954..92cef633 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -34,6 +34,7 @@ logger = logging.getLogger(__name__) def setup_event_forwarding(emitter, forwarder, event_name): def emit(*args, **kwargs): forwarder.emit(event_name, *args, **kwargs) + emitter.on(event_name, emit) @@ -44,6 +45,7 @@ def composite_listener(cls): registers/deregisters all methods named `on_` as a listener for the event with an emitter. """ + def register(self, emitter): for method_name in dir(cls): if method_name.startswith('on_'): @@ -54,7 +56,7 @@ def composite_listener(cls): if method_name.startswith('on_'): emitter.remove_listener(method_name[3:], getattr(self, method_name)) - cls._bumble_register_composite = register + cls._bumble_register_composite = register cls._bumble_deregister_composite = deregister return cls @@ -110,7 +112,9 @@ class AsyncRunner: try: await item except Exception as error: - logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}') + logger.warning( + f'{color("!!! Exception in work queue:", "red")} {error}' + ) # Shared default queue default_queue = WorkQueue() @@ -131,7 +135,9 @@ class AsyncRunner: try: await coroutine except Exception: - logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}') + logger.warning( + f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}' + ) asyncio.create_task(run()) else: @@ -150,18 +156,26 @@ class FlowControlAsyncPipe: paused (by calling a function passed in when the pipe is created) if the amount of queued data exceeds a specified threshold. """ - def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0): - self.pause_source = pause_source + + def __init__( + self, + pause_source, + resume_source, + write_to_sink=None, + drain_sink=None, + threshold=0, + ): + self.pause_source = pause_source self.resume_source = resume_source self.write_to_sink = write_to_sink - self.drain_sink = drain_sink - self.threshold = threshold - self.queue = collections.deque() # Queue of packets - self.queued_bytes = 0 # Number of bytes in the queue + self.drain_sink = drain_sink + self.threshold = threshold + self.queue = collections.deque() # Queue of packets + self.queued_bytes = 0 # Number of bytes in the queue self.ready_to_pump = asyncio.Event() - self.paused = False + self.paused = False self.source_paused = False - self.pump_task = None + self.pump_task = None def start(self): if self.pump_task is None: diff --git a/examples/async_runner.py b/examples/async_runner.py index d0d1a12b..9e71899c 100644 --- a/examples/async_runner.py +++ b/examples/async_runner.py @@ -80,6 +80,7 @@ async def main(): await my_work_queue2.run() print("MAIN: end (should never get here)") + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/battery_client.py b/examples/battery_client.py index f545f129..9f8641e2 100644 --- a/examples/battery_client.py +++ b/examples/battery_client.py @@ -55,7 +55,9 @@ async def main(): # Subscribe to and read the battery level if battery_service.battery_level: await battery_service.battery_level.subscribe( - lambda value: print(f'{color("Battery Level Update:", "green")} {value}') + lambda value: print( + f'{color("Battery Level Update:", "green")} {value}' + ) ) value = await battery_service.battery_level.read_value() print(f'{color("Initial Battery Level:", "green")} {value}') @@ -64,5 +66,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/battery_server.py b/examples/battery_server.py index 68685270..b7f941f1 100644 --- a/examples/battery_server.py +++ b/examples/battery_server.py @@ -44,11 +44,19 @@ async def main(): # Set the advertising data device.advertising_data = bytes( - AdvertisingData([ - (AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Battery', 'utf-8')), - (AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(battery_service.uuid)), - (AdvertisingData.APPEARANCE, struct.pack(' ') + print( + 'Usage: device_information_client.py ' + ) print('example: device_information_client.py usb:0 E1:CA:72:48:C4:E8') return @@ -49,7 +51,9 @@ async def main(): # Discover the Device Information service peer = Peer(connection) print('=== Discovering Device Information Service') - device_information_service = await peer.discover_service_and_create_proxy(DeviceInformationServiceProxy) + device_information_service = await peer.discover_service_and_create_proxy( + DeviceInformationServiceProxy + ) # Check that the service was found if device_information_service is None: @@ -58,23 +62,52 @@ async def main(): # Read and print the fields if device_information_service.manufacturer_name is not None: - print(color('Manufacturer Name: ', 'green'), await device_information_service.manufacturer_name.read_value()) + print( + color('Manufacturer Name: ', 'green'), + await device_information_service.manufacturer_name.read_value(), + ) if device_information_service.model_number is not None: - print(color('Model Number: ', 'green'), await device_information_service.model_number.read_value()) + print( + color('Model Number: ', 'green'), + await device_information_service.model_number.read_value(), + ) if device_information_service.serial_number is not None: - print(color('Serial Number: ', 'green'), await device_information_service.serial_number.read_value()) + print( + color('Serial Number: ', 'green'), + await device_information_service.serial_number.read_value(), + ) if device_information_service.hardware_revision is not None: - print(color('Hardware Revision: ', 'green'), await device_information_service.hardware_revision.read_value()) + print( + color('Hardware Revision: ', 'green'), + await device_information_service.hardware_revision.read_value(), + ) if device_information_service.firmware_revision is not None: - print(color('Firmware Revision: ', 'green'), await device_information_service.firmware_revision.read_value()) + print( + color('Firmware Revision: ', 'green'), + await device_information_service.firmware_revision.read_value(), + ) if device_information_service.software_revision is not None: - print(color('Software Revision: ', 'green'), await device_information_service.software_revision.read_value()) + print( + color('Software Revision: ', 'green'), + await device_information_service.software_revision.read_value(), + ) if device_information_service.system_id is not None: - print(color('System ID: ', 'green'), await device_information_service.system_id.read_value()) - if device_information_service.ieee_regulatory_certification_data_list is not None: - print(color('Regulatory Certification:', 'green'), (await device_information_service.ieee_regulatory_certification_data_list.read_value()).hex()) + print( + color('System ID: ', 'green'), + await device_information_service.system_id.read_value(), + ) + if ( + device_information_service.ieee_regulatory_certification_data_list + is not None + ): + print( + color('Regulatory Certification:', 'green'), + ( + await device_information_service.ieee_regulatory_certification_data_list.read_value() + ).hex(), + ) # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/device_information_server.py b/examples/device_information_server.py index 9c3b6b14..d437caee 100644 --- a/examples/device_information_server.py +++ b/examples/device_information_server.py @@ -39,21 +39,26 @@ async def main(): # Add a Device Information Service to the GATT sever device_information_service = DeviceInformationService( - manufacturer_name = 'ACME', - model_number = 'AB-102', - serial_number = '7654321', - hardware_revision = '1.1.3', - software_revision = '2.5.6', - system_id = (0x123456, 0x8877665544) + manufacturer_name='ACME', + model_number='AB-102', + serial_number='7654321', + hardware_revision='1.1.3', + software_revision='2.5.6', + system_id=(0x123456, 0x8877665544), ) device.add_service(device_information_service) # Set the advertising data device.advertising_data = bytes( - AdvertisingData([ - (AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Device', 'utf-8')), - (AdvertisingData.APPEARANCE, struct.pack('= ord('a') and code <= ord('z'): hid_code = 0x04 + code - ord('a') - input_report_characteristic.value = bytes([0, 0, hid_code, 0, 0, 0, 0, 0]) - await device.notify_subscribers(input_report_characteristic) + input_report_characteristic.value = bytes( + [0, 0, hid_code, 0, 0, 0, 0, 0] + ) + await device.notify_subscribers( + input_report_characteristic + ) elif message_type == 'keyup': - input_report_characteristic.value = bytes.fromhex('0000000000000000') + input_report_characteristic.value = bytes.fromhex( + '0000000000000000' + ) await device.notify_subscribers(input_report_characteristic) except websockets.exceptions.ConnectionClosedOK: pass + await websockets.serve(serve, 'localhost', 8989) await asyncio.get_event_loop().create_future() else: @@ -321,7 +400,9 @@ async def keyboard_device(device, command): # Keypress for the letter keycode = 0x04 + letter - 0x61 - input_report_characteristic.value = bytes([0, 0, keycode, 0, 0, 0, 0, 0]) + input_report_characteristic.value = bytes( + [0, 0, keycode, 0, 0, 0, 0, 0] + ) await device.notify_subscribers(input_report_characteristic) # Key release @@ -335,10 +416,16 @@ async def main(): print('Usage: python keyboard.py ') print(' where is one of:') print(' connect
(run a keyboard host, connecting to a keyboard)') - print(' web (run a keyboard with keypress input from a web page, see keyboard.html') - print(' sim (run a keyboard simulation, emitting a canned sequence of keystrokes') + print( + ' web (run a keyboard with keypress input from a web page, see keyboard.html' + ) + print( + ' sim (run a keyboard simulation, emitting a canned sequence of keystrokes' + ) print('example: python keyboard.py keyboard.json usb:0 sim') - print('example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5') + print( + 'example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5' + ) return async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): @@ -355,5 +442,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_a2dp_info.py b/examples/run_a2dp_info.py index 0d6f66bb..cf63adb0 100644 --- a/examples/run_a2dp_info.py +++ b/examples/run_a2dp_info.py @@ -27,12 +27,9 @@ from bumble.core import ( BT_BR_EDR_TRANSPORT, BT_AVDTP_PROTOCOL_ID, BT_AUDIO_SINK_SERVICE, - BT_L2CAP_PROTOCOL_ID -) -from bumble.avdtp import ( - Protocol as AVDTP_Protocol, - find_avdtp_service_with_connection + BT_L2CAP_PROTOCOL_ID, ) +from bumble.avdtp import Protocol as AVDTP_Protocol, find_avdtp_service_with_connection from bumble.a2dp import make_audio_source_service_sdp_records from bumble.sdp import ( Client as SDP_Client, @@ -40,7 +37,7 @@ from bumble.sdp import ( DataElement, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, ) @@ -48,7 +45,9 @@ from bumble.sdp import ( def sdp_records(): service_record_handle = 0x00010001 return { - service_record_handle: make_audio_source_service_sdp_records(service_record_handle) + service_record_handle: make_audio_source_service_sdp_records( + service_record_handle + ) } @@ -64,8 +63,8 @@ async def find_a2dp_service(device, connection): [ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID - ] + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + ], ) print(color('==================================', 'blue')) @@ -78,8 +77,7 @@ async def find_a2dp_service(device, connection): # Service classes service_class_id_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID + attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID ) if service_class_id_list: if service_class_id_list.value: @@ -89,8 +87,7 @@ async def find_a2dp_service(device, connection): # Protocol info protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID + attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID ) if protocol_descriptor_list: print(color(' Protocol:', 'green')) @@ -103,18 +100,24 @@ async def find_a2dp_service(device, connection): if len(protocol_descriptor.value) >= 2: avdtp_version_major = protocol_descriptor.value[1].value >> 8 avdtp_version_minor = protocol_descriptor.value[1].value & 0xFF - print(f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}') + print( + f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}' + ) service_version = (avdtp_version_major, avdtp_version_minor) # Profile info bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + attribute_list, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID ) if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list.value: - if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: - bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value + if ( + bluetooth_profile_descriptor_list.value[0].type + == DataElement.SEQUENCE + ): + bluetooth_profile_descriptors = ( + bluetooth_profile_descriptor_list.value + ) else: # Sometimes, instead of a list of lists, we just find a list. Fix that bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] @@ -123,7 +126,9 @@ async def find_a2dp_service(device, connection): for bluetooth_profile_descriptor in bluetooth_profile_descriptors: version_major = bluetooth_profile_descriptor.value[1].value >> 8 version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF - print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') + print( + f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}' + ) await sdp_client.disconnect() return service_version @@ -184,5 +189,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_a2dp_sink.py b/examples/run_a2dp_sink.py index bc193a1d..aedf0a49 100644 --- a/examples/run_a2dp_sink.py +++ b/examples/run_a2dp_sink.py @@ -28,7 +28,7 @@ from bumble.avdtp import ( AVDTP_AUDIO_MEDIA_TYPE, Protocol, Listener, - MediaCodecCapabilities + MediaCodecCapabilities, ) from bumble.a2dp import ( make_audio_sink_service_sdp_records, @@ -39,19 +39,19 @@ from bumble.a2dp import ( SBC_LOUDNESS_ALLOCATION_METHOD, SBC_STEREO_CHANNEL_MODE, SBC_JOINT_STEREO_CHANNEL_MODE, - SbcMediaCodecInformation + SbcMediaCodecInformation, ) -Context = { - 'output': None -} +Context = {'output': None} # ----------------------------------------------------------------------------- def sdp_records(): service_record_handle = 0x00010001 return { - service_record_handle: make_audio_sink_service_sdp_records(service_record_handle) + service_record_handle: make_audio_sink_service_sdp_records( + service_record_handle + ) } @@ -59,22 +59,25 @@ def sdp_records(): def codec_capabilities(): # NOTE: this shouldn't be hardcoded, but passed on the command line instead return MediaCodecCapabilities( - media_type = AVDTP_AUDIO_MEDIA_TYPE, - media_codec_type = A2DP_SBC_CODEC_TYPE, - media_codec_information = SbcMediaCodecInformation.from_lists( - sampling_frequencies = [48000, 44100, 32000, 16000], - channel_modes = [ + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_lists( + sampling_frequencies=[48000, 44100, 32000, 16000], + channel_modes=[ SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE, - SBC_JOINT_STEREO_CHANNEL_MODE + SBC_JOINT_STEREO_CHANNEL_MODE, ], - block_lengths = [4, 8, 12, 16], - subbands = [4, 8], - allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD], - minimum_bitpool_value = 2, - maximum_bitpool_value = 53 - ) + block_lengths=[4, 8, 12, 16], + subbands=[4, 8], + allocation_methods=[ + SBC_LOUDNESS_ALLOCATION_METHOD, + SBC_SNR_ALLOCATION_METHOD, + ], + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), ) @@ -87,10 +90,10 @@ def on_avdtp_connection(server): # ----------------------------------------------------------------------------- def on_rtp_packet(packet): - header = packet.payload[0] - fragmented = header >> 7 - start = (header >> 6) & 0x01 - last = (header >> 5) & 0x01 + header = packet.payload[0] + fragmented = header >> 7 + start = (header >> 6) & 0x01 + last = (header >> 5) & 0x01 number_of_frames = header & 0x0F if fragmented: @@ -104,7 +107,9 @@ def on_rtp_packet(packet): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 4: - print('Usage: run_a2dp_sink.py []') + print( + 'Usage: run_a2dp_sink.py []' + ) print('example: run_a2dp_sink.py classic1.json usb:0 output.sbc') return @@ -133,7 +138,9 @@ async def main(): # Connect to the source target_address = sys.argv[4] print(f'=== Connecting to {target_address}...') - connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) + connection = await device.connect( + target_address, transport=BT_BR_EDR_TRANSPORT + ) print(f'=== Connected to {connection.peer_address}!') # Request authentication @@ -159,5 +166,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_a2dp_source.py b/examples/run_a2dp_source.py index 45abad96..708a1986 100644 --- a/examples/run_a2dp_source.py +++ b/examples/run_a2dp_source.py @@ -30,7 +30,7 @@ from bumble.avdtp import ( MediaCodecCapabilities, MediaPacketPump, Protocol, - Listener + Listener, ) from bumble.a2dp import ( SBC_JOINT_STEREO_CHANNEL_MODE, @@ -38,7 +38,7 @@ from bumble.a2dp import ( make_audio_source_service_sdp_records, A2DP_SBC_CODEC_TYPE, SbcMediaCodecInformation, - SbcPacketSource + SbcPacketSource, ) @@ -46,7 +46,9 @@ from bumble.a2dp import ( def sdp_records(): service_record_handle = 0x00010001 return { - service_record_handle: make_audio_source_service_sdp_records(service_record_handle) + service_record_handle: make_audio_source_service_sdp_records( + service_record_handle + ) } @@ -54,23 +56,25 @@ def sdp_records(): def codec_capabilities(): # NOTE: this shouldn't be hardcoded, but should be inferred from the input file instead return MediaCodecCapabilities( - media_type = AVDTP_AUDIO_MEDIA_TYPE, - media_codec_type = A2DP_SBC_CODEC_TYPE, - media_codec_information = SbcMediaCodecInformation.from_discrete_values( - sampling_frequency = 44100, - channel_mode = SBC_JOINT_STEREO_CHANNEL_MODE, - block_length = 16, - subbands = 8, - allocation_method = SBC_LOUDNESS_ALLOCATION_METHOD, - minimum_bitpool_value = 2, - maximum_bitpool_value = 53 - ) + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_discrete_values( + sampling_frequency=44100, + channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE, + block_length=16, + subbands=8, + allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), ) # ----------------------------------------------------------------------------- def on_avdtp_connection(read_function, protocol): - packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities()) + packet_source = SbcPacketSource( + read_function, protocol.l2cap_channel.mtu, codec_capabilities() + ) packet_pump = MediaPacketPump(packet_source.packets) protocol.add_source(packet_source.codec_capabilities, packet_pump) @@ -83,14 +87,18 @@ async def stream_packets(read_function, protocol): print('@@@', endpoint) # Select a sink - sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE) + sink = protocol.find_remote_sink_by_codec( + AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE + ) if sink is None: print(color('!!! no SBC sink found', 'red')) return print(f'### Selected sink: {sink.seid}') # Stream the packets - packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities()) + packet_source = SbcPacketSource( + read_function, protocol.l2cap_channel.mtu, codec_capabilities() + ) packet_pump = MediaPacketPump(packet_source.packets) source = protocol.add_source(packet_source.codec_capabilities, packet_pump) stream = await protocol.create_stream(source, sink) @@ -107,8 +115,12 @@ async def stream_packets(read_function, protocol): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 4: - print('Usage: run_a2dp_source.py []') - print('example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8') + print( + 'Usage: run_a2dp_source.py []' + ) + print( + 'example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8' + ) return print('<<< connecting to HCI...') @@ -134,7 +146,9 @@ async def main(): # Connect to a peer target_address = sys.argv[4] print(f'=== Connecting to {target_address}...') - connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) + connection = await device.connect( + target_address, transport=BT_BR_EDR_TRANSPORT + ) print(f'=== Connected to {connection.peer_address}!') # Request authentication @@ -148,7 +162,9 @@ async def main(): print('*** Encryption on') # Look for an A2DP service - avdtp_version = await find_avdtp_service_with_connection(device, connection) + avdtp_version = await find_avdtp_service_with_connection( + device, connection + ) if not avdtp_version: print(color('!!! no A2DP service found')) return @@ -161,7 +177,9 @@ async def main(): else: # Create a listener to wait for AVDTP connections listener = Listener(Listener.create_registrar(device), version=(1, 2)) - listener.on('connection', lambda protocol: on_avdtp_connection(read, protocol)) + listener.on( + 'connection', lambda protocol: on_avdtp_connection(read, protocol) + ) # Become connectable and wait for a connection await device.set_discoverable(True) @@ -171,5 +189,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_advertiser.py b/examples/run_advertiser.py index e54bc37e..c535c661 100644 --- a/examples/run_advertiser.py +++ b/examples/run_advertiser.py @@ -30,7 +30,9 @@ from bumble.transport import open_transport_or_link # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: - print('Usage: run_advertiser.py [type] [address]') + print( + 'Usage: run_advertiser.py [type] [address]' + ) print('example: run_advertiser.py device1.json usb:0') return @@ -56,6 +58,7 @@ async def main(): await device.start_advertising(advertising_type=advertising_type, target=target) await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_asha_sink.py b/examples/run_asha_sink.py index bebb5de7..46b8dbb5 100644 --- a/examples/run_asha_sink.py +++ b/examples/run_asha_sink.py @@ -25,28 +25,34 @@ from bumble.core import AdvertisingData from bumble.device import Device from bumble.transport import open_transport_or_link from bumble.hci import UUID -from bumble.gatt import ( - Service, - Characteristic, - CharacteristicValue -) +from bumble.gatt import Service, Characteristic, CharacteristicValue # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid') -ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties') -ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint') -ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus') -ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume') -ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT') +ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid') +ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID( + '6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties' +) +ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID( + 'f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint' +) +ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID( + '38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus' +) +ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume') +ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID( + '2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT' +) # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) != 4: - print('Usage: python run_asha_sink.py ') + print( + 'Usage: python run_asha_sink.py ' + ) print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722') return @@ -62,14 +68,18 @@ async def main(): if opcode == 1: # Start audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]] - print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}') + print( + f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}' + ) elif opcode == 2: print('### STOP') elif opcode == 3: print(f'### STATUS: connected={value[1]}') # Respond with a status - asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True)) + asyncio.create_task( + device.notify_subscribers(audio_status_characteristic, force=True) + ) # Handler for volume control def on_volume_write(connection, value): @@ -91,63 +101,91 @@ async def main(): ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - bytes([ - 0x01, # Version - 0x00, # Device Capabilities [Left, Monaural] - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId - 0x01, # Feature Map [LE CoC audio output streaming supported] - 0x00, 0x00, # Render Delay - 0x00, 0x00, # RFU - 0x02, 0x00 # Codec IDs [G.722 at 16 kHz] - ]) + bytes( + [ + 0x01, # Version + 0x00, # Device Capabilities [Left, Monaural] + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + 0x06, + 0x07, + 0x08, # HiSyncId + 0x01, # Feature Map [LE CoC audio output streaming supported] + 0x00, + 0x00, # Render Delay + 0x00, + 0x00, # RFU + 0x02, + 0x00, # Codec IDs [G.722 at 16 kHz] + ] + ), ) audio_control_point_characteristic = Characteristic( ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC, Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.WRITEABLE, - CharacteristicValue(write=on_audio_control_point_write) + CharacteristicValue(write=on_audio_control_point_write), ) audio_status_characteristic = Characteristic( ASHA_AUDIO_STATUS_CHARACTERISTIC, Characteristic.READ | Characteristic.NOTIFY, Characteristic.READABLE, - bytes([0]) + bytes([0]), ) volume_characteristic = Characteristic( ASHA_VOLUME_CHARACTERISTIC, Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.WRITEABLE, - CharacteristicValue(write=on_volume_write) + CharacteristicValue(write=on_volume_write), ) le_psm_out_characteristic = Characteristic( ASHA_LE_PSM_OUT_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, - struct.pack(' ') - print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8') + print( + 'Usage: run_classic_connect.py ' + ) + print( + 'example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8' + ) return print('<<< connecting to HCI...') @@ -53,32 +61,49 @@ async def main(): await sdp_client.connect(connection) # List all services in the root browse group - service_record_handles = await sdp_client.search_services([SDP_PUBLIC_BROWSE_ROOT]) + service_record_handles = await sdp_client.search_services( + [SDP_PUBLIC_BROWSE_ROOT] + ) print(color('\n==================================', 'blue')) print(color('SERVICES:', 'yellow'), service_record_handles) # For each service in the root browse group, get all its attributes for service_record_handle in service_record_handles: - attributes = await sdp_client.get_attributes(service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE]) + attributes = await sdp_client.get_attributes( + service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE] + ) print(color(f'SERVICE {service_record_handle:04X} attributes:', 'yellow')) for attribute in attributes: print(' ', attribute.to_string(color=True)) # Search for services with an L2CAP service attribute - search_result = await sdp_client.search_attributes([BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE]) + search_result = await sdp_client.search_attributes( + [BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE] + ) print(color('\n==================================', 'blue')) print(color('SEARCH RESULTS:', 'yellow')) for attribute_list in search_result: print(color('SERVICE:', 'green')) - print(' ' + '\n '.join([attribute.to_string(color=True) for attribute in attribute_list])) + print( + ' ' + + '\n '.join( + [attribute.to_string(color=True) for attribute in attribute_list] + ) + ) await sdp_client.disconnect() await hci_source.wait_for_termination() # Connect to a peer target_addresses = sys.argv[3:] - await asyncio.wait([asyncio.create_task(connect(target_address)) for target_address in target_addresses]) + await asyncio.wait( + [ + asyncio.create_task(connect(target_address)) + for target_address in target_addresses + ] + ) + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_classic_discoverable.py b/examples/run_classic_discoverable.py index 5cdbc275..076a9ec5 100644 --- a/examples/run_classic_discoverable.py +++ b/examples/run_classic_discoverable.py @@ -30,48 +30,62 @@ from bumble.sdp import ( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) from bumble.core import ( BT_AUDIO_SINK_SERVICE, BT_L2CAP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID, - BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE + BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, ) # ----------------------------------------------------------------------------- SDP_SERVICE_RECORDS = { 0x00010001: [ - ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)), - ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) - ])), + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(0x00010001), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), ServiceAttribute( SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]) + DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]), ), ServiceAttribute( SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID), - DataElement.unsigned_integer_16(25) - ]), - DataElement.sequence([ - DataElement.uuid(BT_AVDTP_PROTOCOL_ID), - DataElement.unsigned_integer_16(256) - ]) - ]) + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(25), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVDTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(256), + ] + ), + ] + ), ), ServiceAttribute( SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(256) - ]) - ]) - ) + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(256), + ] + ) + ] + ), + ), ] } @@ -99,6 +113,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_classic_discovery.py b/examples/run_classic_discovery.py index b0ab5eec..cd3240e2 100644 --- a/examples/run_classic_discovery.py +++ b/examples/run_classic_discovery.py @@ -29,13 +29,23 @@ from bumble.core import DeviceClass # ----------------------------------------------------------------------------- class DiscoveryListener(Device.Listener): def on_inquiry_result(self, address, class_of_device, eir_data, rssi): - service_classes, major_device_class, minor_device_class = DeviceClass.split_class_of_device(class_of_device) + ( + service_classes, + major_device_class, + minor_device_class, + ) = DeviceClass.split_class_of_device(class_of_device) separator = '\n ' print(f'>>> {color(address, "yellow")}:') print(f' Device Class (raw): {class_of_device:06X}') - print(f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}') - print(f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}') - print(f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}') + print( + f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}' + ) + print( + f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}' + ) + print( + f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}' + ) print(f' RSSI: {rssi}') if eir_data.ad_structures: print(f' {eir_data.to_string(separator)}') @@ -59,6 +69,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_connect_and_encrypt.py b/examples/run_connect_and_encrypt.py index 0ee868ad..ed47686d 100644 --- a/examples/run_connect_and_encrypt.py +++ b/examples/run_connect_and_encrypt.py @@ -27,8 +27,12 @@ from bumble.transport import open_transport_or_link # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: - print('Usage: run_connect_and_encrypt.py ') - print('example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8') + print( + 'Usage: run_connect_and_encrypt.py ' + ) + print( + 'example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8' + ) return print('<<< connecting to HCI...') @@ -53,6 +57,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_controller.py b/examples/run_controller.py index d8295d4f..3298d349 100644 --- a/examples/run_controller.py +++ b/examples/run_controller.py @@ -32,8 +32,12 @@ from bumble.transport import open_transport_or_link # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) != 4: - print('Usage: run_controller.py ') - print('example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333') + print( + 'Usage: run_controller.py ' + ) + print( + 'example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333' + ) return print('>>> connecting to HCI...') @@ -44,11 +48,13 @@ async def main(): link = LocalLink() # Create a first controller using the packet source/sink as its host interface - controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link) + controller1 = Controller( + 'C1', host_source=hci_source, host_sink=hci_sink, link=link + ) controller1.random_address = sys.argv[1] # Create a second controller using the same link - controller2 = Controller('C2', link = link) + controller2 = Controller('C2', link=link) # Create a host for the second controller host = Host() @@ -59,17 +65,21 @@ async def main(): device.host = host # Add some basic services to the device's GATT server - descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') + descriptor = Descriptor( + GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, + Descriptor.READABLE, + 'My Description', + ) manufacturer_name_characteristic = Characteristic( GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, "Fitbit", - [descriptor] + [descriptor], + ) + device_info_service = Service( + GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic] ) - device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [ - manufacturer_name_characteristic - ]) device.add_service(device_info_service) # Debug print @@ -82,6 +92,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_controller_with_scanner.py b/examples/run_controller_with_scanner.py index 18ba2743..6a70a02b 100644 --- a/examples/run_controller_with_scanner.py +++ b/examples/run_controller_with_scanner.py @@ -37,7 +37,9 @@ class ScannerListener(Device.Listener): else: type_color = 'cyan' - print(f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]: RSSI={advertisement.rssi}, {advertisement.data}') + print( + f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]: RSSI={advertisement.rssi}, {advertisement.data}' + ) # ----------------------------------------------------------------------------- @@ -55,20 +57,25 @@ async def main(): link = LocalLink() # Create a first controller using the packet source/sink as its host interface - controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link) + controller1 = Controller( + 'C1', host_source=hci_source, host_sink=hci_sink, link=link + ) controller1.address = 'E0:E1:E2:E3:E4:E5' # Create a second controller using the same link - controller2 = Controller('C2', link = link) + controller2 = Controller('C2', link=link) # Create a device with a scanner listener - device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2) + device = Device.with_hci( + 'Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2 + ) device.listener = ScannerListener() await device.power_on() await device.start_scanning() await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_gatt_client.py b/examples/run_gatt_client.py index 5af86fb1..59f57424 100644 --- a/examples/run_gatt_client.py +++ b/examples/run_gatt_client.py @@ -70,7 +70,9 @@ class Listener(Device.Listener): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: - print('Usage: run_gatt_client.py []') + print( + 'Usage: run_gatt_client.py []' + ) print('example: run_gatt_client.py device1.json usb:0 E1:CA:72:48:C4:E8') return @@ -93,6 +95,7 @@ async def main(): await asyncio.get_running_loop().create_future() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_gatt_client_and_server.py b/examples/run_gatt_client_and_server.py index 940b1a83..6586ca4d 100644 --- a/examples/run_gatt_client_and_server.py +++ b/examples/run_gatt_client_and_server.py @@ -32,7 +32,7 @@ from bumble.gatt import ( show_services, GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, - GATT_DEVICE_INFORMATION_SERVICE + GATT_DEVICE_INFORMATION_SERVICE, ) @@ -48,32 +48,36 @@ async def main(): link = LocalLink() # Setup a stack for the client - client_controller = Controller("client controller", link = link) + client_controller = Controller("client controller", link=link) client_host = Host() client_host.controller = client_controller - client_device = Device("client", address = 'F0:F1:F2:F3:F4:F5', host = client_host) + client_device = Device("client", address='F0:F1:F2:F3:F4:F5', host=client_host) await client_device.power_on() # Setup a stack for the server - server_controller = Controller("server controller", link = link) + server_controller = Controller("server controller", link=link) server_host = Host() server_host.controller = server_controller - server_device = Device("server", address = 'F6:F7:F8:F9:FA:FB', host = server_host) + server_device = Device("server", address='F6:F7:F8:F9:FA:FB', host=server_host) server_device.listener = ServerListener() await server_device.power_on() # Add a few entries to the device's GATT server - descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') + descriptor = Descriptor( + GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, + Descriptor.READABLE, + 'My Description', + ) manufacturer_name_characteristic = Characteristic( GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, "Fitbit", - [descriptor] + [descriptor], + ) + device_info_service = Service( + GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic] ) - device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [ - manufacturer_name_characteristic - ]) server_device.add_service(device_info_service) # Connect the client to the server @@ -109,6 +113,7 @@ async def main(): await asyncio.get_running_loop().create_future() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_gatt_server.py b/examples/run_gatt_server.py index 099d589a..463a7074 100644 --- a/examples/run_gatt_server.py +++ b/examples/run_gatt_server.py @@ -22,10 +22,7 @@ import logging from bumble.device import Device, Connection from bumble.transport import open_transport_or_link -from bumble.att import ( - ATT_Error, - ATT_INSUFFICIENT_ENCRYPTION_ERROR -) +from bumble.att import ATT_Error, ATT_INSUFFICIENT_ENCRYPTION_ERROR from bumble.gatt import ( Service, Characteristic, @@ -33,7 +30,7 @@ from bumble.gatt import ( Descriptor, GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, - GATT_DEVICE_INFORMATION_SERVICE + GATT_DEVICE_INFORMATION_SERVICE, ) @@ -76,7 +73,9 @@ def my_custom_write_with_error(connection, value): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: - print('Usage: run_gatt_server.py []') + print( + 'Usage: run_gatt_server.py []' + ) print('example: run_gatt_server.py device1.json usb:0 E1:CA:72:48:C4:E8') return @@ -89,17 +88,21 @@ async def main(): device.listener = Listener(device) # Add a few entries to the device's GATT server - descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') + descriptor = Descriptor( + GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, + Descriptor.READABLE, + 'My Description', + ) manufacturer_name_characteristic = Characteristic( GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 'Fitbit', - [descriptor] + [descriptor], + ) + device_info_service = Service( + GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic] ) - device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [ - manufacturer_name_characteristic - ]) custom_service1 = Service( '50DB505C-8AC4-4738-8448-3B1D9CC09CC5', [ @@ -107,21 +110,23 @@ async def main(): 'D901B45B-4916-412E-ACCA-376ECB603B2C', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - CharacteristicValue(read=my_custom_read, write=my_custom_write) + CharacteristicValue(read=my_custom_read, write=my_custom_write), ), Characteristic( '552957FB-CF1F-4A31-9535-E78847E1A714', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - CharacteristicValue(read=my_custom_read_with_error, write=my_custom_write_with_error) + CharacteristicValue( + read=my_custom_read_with_error, write=my_custom_write_with_error + ), ), Characteristic( '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', Characteristic.READ | Characteristic.NOTIFY, Characteristic.READABLE, - 'hello' - ) - ] + 'hello', + ), + ], ) device.add_services([device_info_service, custom_service1]) @@ -142,6 +147,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_hfp_gateway.py b/examples/run_hfp_gateway.py index 69db0c1a..90656ba3 100644 --- a/examples/run_hfp_gateway.py +++ b/examples/run_hfp_gateway.py @@ -31,12 +31,9 @@ from bumble.sdp import ( ServiceAttribute, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID -) -from bumble.hci import ( - BT_HANDSFREE_SERVICE, - BT_RFCOMM_PROTOCOL_ID + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) +from bumble.hci import BT_HANDSFREE_SERVICE, BT_RFCOMM_PROTOCOL_ID from bumble.hfp import HfpProtocol @@ -52,8 +49,8 @@ async def list_rfcomm_channels(device, connection): [ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID - ] + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + ], ) print(color('==================================', 'blue')) print(color('Handsfree Services:', 'yellow')) @@ -61,40 +58,59 @@ async def list_rfcomm_channels(device, connection): for attribute_list in search_result: # Look for the RFCOMM Channel number protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID + attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID ) if protocol_descriptor_list: for protocol_descriptor in protocol_descriptor_list.value: if len(protocol_descriptor.value) >= 2: if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID: print(color('SERVICE:', 'green')) - print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value) + print( + color(' RFCOMM Channel:', 'cyan'), + protocol_descriptor.value[1].value, + ) rfcomm_channels.append(protocol_descriptor.value[1].value) # List profiles - bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + bluetooth_profile_descriptor_list = ( + ServiceAttribute.find_attribute_in_list( + attribute_list, + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + ) ) if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list.value: - if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: - bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value + if ( + bluetooth_profile_descriptor_list.value[0].type + == DataElement.SEQUENCE + ): + bluetooth_profile_descriptors = ( + bluetooth_profile_descriptor_list.value + ) else: # Sometimes, instead of a list of lists, we just find a list. Fix that - bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] + bluetooth_profile_descriptors = [ + bluetooth_profile_descriptor_list + ] print(color(' Profiles:', 'green')) - for bluetooth_profile_descriptor in bluetooth_profile_descriptors: - version_major = bluetooth_profile_descriptor.value[1].value >> 8 - version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF - print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') + for ( + bluetooth_profile_descriptor + ) in bluetooth_profile_descriptors: + version_major = ( + bluetooth_profile_descriptor.value[1].value >> 8 + ) + version_minor = ( + bluetooth_profile_descriptor.value[1].value + & 0xFF + ) + print( + f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}' + ) # List service classes service_class_id_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID + attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID ) if service_class_id_list: if service_class_id_list.value: @@ -109,9 +125,15 @@ async def list_rfcomm_channels(device, connection): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 4: - print('Usage: run_hfp_gateway.py ') - print(' specifying a channel number, or "discover" to list all RFCOMM channels') - print('example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8') + print( + 'Usage: run_hfp_gateway.py ' + ) + print( + ' specifying a channel number, or "discover" to list all RFCOMM channels' + ) + print( + 'example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8' + ) return print('<<< connecting to HCI...') @@ -173,7 +195,9 @@ async def main(): protocol.send_response_line('+BRSF: 30') protocol.send_response_line('OK') elif line.startswith('AT+CIND=?'): - protocol.send_response_line('+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))') + protocol.send_response_line( + '+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))' + ) protocol.send_response_line('OK') elif line.startswith('AT+CIND?'): protocol.send_response_line('+CIND: 0,0,1,4,1,5,0') @@ -193,7 +217,9 @@ async def main(): elif line.startswith('AT+BIA='): protocol.send_response_line('OK') elif line.startswith('AT+BVRA='): - protocol.send_response_line('+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"') + protocol.send_response_line( + '+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"' + ) elif line.startswith('AT+XEVENT='): protocol.send_response_line('OK') elif line.startswith('AT+XAPL='): @@ -204,6 +230,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_hfp_handsfree.py b/examples/run_hfp_handsfree.py index cf7a0535..85ba3ded 100644 --- a/examples/run_hfp_handsfree.py +++ b/examples/run_hfp_handsfree.py @@ -32,13 +32,13 @@ from bumble.sdp import ( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) from bumble.core import ( BT_GENERIC_AUDIO_SERVICE, BT_HANDSFREE_SERVICE, BT_L2CAP_PROTOCOL_ID, - BT_RFCOMM_PROTOCOL_ID + BT_RFCOMM_PROTOCOL_ID, ) from bumble.hfp import HfpProtocol @@ -49,36 +49,44 @@ def make_sdp_records(rfcomm_channel): 0x00010001: [ ServiceAttribute( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(0x00010001) + DataElement.unsigned_integer_32(0x00010001), ), ServiceAttribute( SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - DataElement.sequence([ - DataElement.uuid(BT_HANDSFREE_SERVICE), - DataElement.uuid(BT_GENERIC_AUDIO_SERVICE) - ]) + DataElement.sequence( + [ + DataElement.uuid(BT_HANDSFREE_SERVICE), + DataElement.uuid(BT_GENERIC_AUDIO_SERVICE), + ] + ), ), ServiceAttribute( SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID) - ]), - DataElement.sequence([ - DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), - DataElement.unsigned_integer_8(rfcomm_channel) - ]) - ]) + DataElement.sequence( + [ + DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]), + DataElement.sequence( + [ + DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), + DataElement.unsigned_integer_8(rfcomm_channel), + ] + ), + ] + ), ), ServiceAttribute( SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_HANDSFREE_SERVICE), - DataElement.unsigned_integer_16(0x0105) - ]) - ]) - ) + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_HANDSFREE_SERVICE), + DataElement.unsigned_integer_16(0x0105), + ] + ) + ] + ), + ), ] } @@ -103,6 +111,7 @@ class UiServer: except websockets.exceptions.ConnectionClosedOK: pass + await websockets.serve(serve, 'localhost', 8989) @@ -111,7 +120,7 @@ async def protocol_loop(protocol): await protocol.initialize_service() while True: - await(protocol.next_line()) + await (protocol.next_line()) # ----------------------------------------------------------------------------- @@ -160,6 +169,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_notifier.py b/examples/run_notifier.py index 15173cbd..2772908d 100644 --- a/examples/run_notifier.py +++ b/examples/run_notifier.py @@ -23,10 +23,7 @@ import logging from bumble.device import Device, Connection from bumble.transport import open_transport_or_link -from bumble.gatt import ( - Service, - Characteristic -) +from bumble.gatt import Service, Characteristic # ----------------------------------------------------------------------------- @@ -41,7 +38,9 @@ class Listener(Device.Listener, Connection.Listener): def on_disconnection(self, reason): print(f'### Disconnected, reason={reason}') - def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled): + def on_characteristic_subscription( + self, connection, characteristic, notify_enabled, indicate_enabled + ): print( f'$$$ Characteristic subscription for handle {characteristic.handle} from {connection}: ' f'notify {"enabled" if notify_enabled else "disabled"}, ' @@ -55,6 +54,7 @@ class Listener(Device.Listener, Connection.Listener): def on_my_characteristic_subscription(peer, enabled): print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}') + # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: @@ -75,24 +75,24 @@ async def main(): '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', Characteristic.READ | Characteristic.NOTIFY, Characteristic.READABLE, - bytes([0x40]) + bytes([0x40]), ) characteristic2 = Characteristic( '8EBDEBAE-0017-418E-8D3B-3A3809492165', Characteristic.READ | Characteristic.INDICATE, Characteristic.READABLE, - bytes([0x41]) + bytes([0x41]), ) characteristic3 = Characteristic( '8EBDEBAE-0017-418E-8D3B-3A3809492165', Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE, Characteristic.READABLE, - bytes([0x42]) + bytes([0x42]), ) characteristic3.on('subscription', on_my_characteristic_subscription) custom_service = Service( '50DB505C-8AC4-4738-8448-3B1D9CC09CC5', - [characteristic1, characteristic2, characteristic3] + [characteristic1, characteristic2, characteristic3], ) device.add_services([custom_service]) @@ -123,5 +123,5 @@ async def main(): # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_rfcomm_client.py b/examples/run_rfcomm_client.py index 76586c3e..79d35b73 100644 --- a/examples/run_rfcomm_client.py +++ b/examples/run_rfcomm_client.py @@ -31,7 +31,7 @@ from bumble.sdp import ( ServiceAttribute, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID @@ -48,47 +48,66 @@ async def list_rfcomm_channels(device, connection): [ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID - ] + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + ], ) print(color('==================================', 'blue')) print(color('RFCOMM Services:', 'yellow')) for attribute_list in search_result: # Look for the RFCOMM Channel number protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID + attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID ) if protocol_descriptor_list: for protocol_descriptor in protocol_descriptor_list.value: if len(protocol_descriptor.value) >= 2: if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID: print(color('SERVICE:', 'green')) - print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value) + print( + color(' RFCOMM Channel:', 'cyan'), + protocol_descriptor.value[1].value, + ) # List profiles - bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID + bluetooth_profile_descriptor_list = ( + ServiceAttribute.find_attribute_in_list( + attribute_list, + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + ) ) if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list.value: - if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: - bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value + if ( + bluetooth_profile_descriptor_list.value[0].type + == DataElement.SEQUENCE + ): + bluetooth_profile_descriptors = ( + bluetooth_profile_descriptor_list.value + ) else: # Sometimes, instead of a list of lists, we just find a list. Fix that - bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] + bluetooth_profile_descriptors = [ + bluetooth_profile_descriptor_list + ] print(color(' Profiles:', 'green')) - for bluetooth_profile_descriptor in bluetooth_profile_descriptors: - version_major = bluetooth_profile_descriptor.value[1].value >> 8 - version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF - print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') + for ( + bluetooth_profile_descriptor + ) in bluetooth_profile_descriptors: + version_major = ( + bluetooth_profile_descriptor.value[1].value >> 8 + ) + version_minor = ( + bluetooth_profile_descriptor.value[1].value + & 0xFF + ) + print( + f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}' + ) # List service classes service_class_id_list = ServiceAttribute.find_attribute_in_list( - attribute_list, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID + attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID ) if service_class_id_list: if service_class_id_list.value: @@ -138,9 +157,15 @@ async def tcp_server(tcp_port, rfcomm_session): # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 5: - print('Usage: run_rfcomm_client.py |discover [tcp-port]') - print(' specifying a channel number, or "discover" to list all RFCOMM channels') - print('example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8') + print( + 'Usage: run_rfcomm_client.py |discover [tcp-port]' + ) + print( + ' specifying a channel number, or "discover" to list all RFCOMM channels' + ) + print( + 'example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8' + ) return print('<<< connecting to HCI...') @@ -197,6 +222,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_rfcomm_server.py b/examples/run_rfcomm_server.py index a239cebd..8f77d9f4 100644 --- a/examples/run_rfcomm_server.py +++ b/examples/run_rfcomm_server.py @@ -31,7 +31,7 @@ from bumble.sdp import ( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, ) from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID @@ -40,22 +40,34 @@ from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID def sdp_records(channel): return { 0x00010001: [ - ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)), - ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) - ])), - ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')) - ])), - ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ - DataElement.sequence([ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID) - ]), - DataElement.sequence([ - DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), - DataElement.unsigned_integer_8(channel) - ]) - ])) + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(0x00010001), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]), + DataElement.sequence( + [ + DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), + DataElement.unsigned_integer_8(channel), + ] + ), + ] + ), + ), ] } @@ -113,6 +125,7 @@ async def main(): await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/examples/run_scanner.py b/examples/run_scanner.py index 719e58ed..3e47a4e1 100644 --- a/examples/run_scanner.py +++ b/examples/run_scanner.py @@ -35,13 +35,15 @@ async def main(): print('<<< connecting to HCI...') async with await open_transport_or_link(sys.argv[1]) as (hci_source, hci_sink): print('<<< connected') - filter_duplicates = (len(sys.argv) == 3 and sys.argv[2] == 'filter') + filter_duplicates = len(sys.argv) == 3 and sys.argv[2] == 'filter' device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) @device.on('advertisement') def _(advertisement): - address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[advertisement.address.address_type] + address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[ + advertisement.address.address_type + ] address_color = 'yellow' if advertisement.is_connectable else 'red' address_qualifier = '' if address_type_string.startswith('P'): @@ -57,13 +59,16 @@ async def main(): type_color = 'white' separator = '\n ' - print(f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{advertisement.rssi}{separator}{advertisement.data.to_string(separator)}') + print( + f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{advertisement.rssi}{separator}{advertisement.data.to_string(separator)}' + ) await device.power_on() await device.start_scanning(filter_duplicates=filter_duplicates) await hci_source.wait_for_termination() + # ----------------------------------------------------------------------------- -logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) diff --git a/setup.py b/setup.py index 16ecaca6..1380bc95 100644 --- a/setup.py +++ b/setup.py @@ -13,4 +13,5 @@ # limitations under the License. from setuptools import setup + setup() diff --git a/tasks.py b/tasks.py index c361e94d..14800f08 100644 --- a/tasks.py +++ b/tasks.py @@ -27,6 +27,7 @@ ns = Collection() build_tasks = Collection() ns.add_collection(build_tasks, name="build") + @task def build(ctx, install=False): if install: @@ -34,18 +35,23 @@ def build(ctx, install=False): ctx.run("python -m build") + build_tasks.add_task(build, default=True) + @task def release_build(ctx): build(ctx, install=True) + build_tasks.add_task(release_build, name="release") + @task def mkdocs(ctx): ctx.run("mkdocs build -f docs/mkdocs/mkdocs.yml") + build_tasks.add_task(mkdocs, name="mkdocs") # Testing @@ -70,10 +76,13 @@ def test(ctx, filter=None, junit=False, install=False, html=False, verbose=0): args += f" -{'v' * verbose}" ctx.run(f"python -m pytest {os.path.join(ROOT_DIR, 'tests')} {args}") + test_tasks.add_task(test, default=True) + @task def release_test(ctx): test(ctx, install=True) + test_tasks.add_task(release_test, name="release") diff --git a/tests/a2dp_test.py b/tests/a2dp_test.py index e09694db..e4995315 100644 --- a/tests/a2dp_test.py +++ b/tests/a2dp_test.py @@ -35,7 +35,7 @@ from bumble.avdtp import ( MediaPacket, AVDTP_AUDIO_MEDIA_TYPE, AVDTP_TSEP_SNK, - A2DP_SBC_CODEC_TYPE + A2DP_SBC_CODEC_TYPE, ) from bumble.a2dp import ( SbcMediaCodecInformation, @@ -44,7 +44,7 @@ from bumble.a2dp import ( SBC_STEREO_CHANNEL_MODE, SBC_JOINT_STEREO_CHANNEL_MODE, SBC_LOUDNESS_ALLOCATION_METHOD, - SBC_SNR_ALLOCATION_METHOD + SBC_SNR_ALLOCATION_METHOD, ) # ----------------------------------------------------------------------------- @@ -60,18 +60,18 @@ class TwoDevices: self.link = LocalLink() self.controllers = [ - Controller('C1', link = self.link), - Controller('C2', link = self.link) + Controller('C1', link=self.link), + Controller('C2', link=self.link), ] self.devices = [ Device( - address = 'F0:F1:F2:F3:F4:F5', - host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) + address='F0:F1:F2:F3:F4:F5', + host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address = 'F5:F4:F3:F2:F1:F0', - host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) - ) + address='F5:F4:F3:F2:F1:F0', + host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), + ), ] self.paired = [None, None] @@ -87,8 +87,12 @@ async def test_self_connection(): two_devices = TwoDevices() # Attach listeners - two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) - two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) + two_devices.devices[0].on( + 'connection', lambda connection: two_devices.on_connection(0, connection) + ) + two_devices.devices[1].on( + 'connection', lambda connection: two_devices.on_connection(1, connection) + ) # Start await two_devices.devices[0].power_on() @@ -98,46 +102,49 @@ async def test_self_connection(): await two_devices.devices[0].connect(two_devices.devices[1].random_address) # Check the post conditions - assert(two_devices.connections[0] is not None) - assert(two_devices.connections[1] is not None) + assert two_devices.connections[0] is not None + assert two_devices.connections[1] is not None # ----------------------------------------------------------------------------- def source_codec_capabilities(): return MediaCodecCapabilities( - media_type = AVDTP_AUDIO_MEDIA_TYPE, - media_codec_type = A2DP_SBC_CODEC_TYPE, - media_codec_information = SbcMediaCodecInformation.from_discrete_values( - sampling_frequency = 44100, - channel_mode = SBC_JOINT_STEREO_CHANNEL_MODE, - block_length = 16, - subbands = 8, - allocation_method = SBC_LOUDNESS_ALLOCATION_METHOD, - minimum_bitpool_value = 2, - maximum_bitpool_value = 53 - ) + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_discrete_values( + sampling_frequency=44100, + channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE, + block_length=16, + subbands=8, + allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), ) # ----------------------------------------------------------------------------- def sink_codec_capabilities(): return MediaCodecCapabilities( - media_type = AVDTP_AUDIO_MEDIA_TYPE, - media_codec_type = A2DP_SBC_CODEC_TYPE, - media_codec_information = SbcMediaCodecInformation.from_lists( - sampling_frequencies = [48000, 44100, 32000, 16000], - channel_modes = [ + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_lists( + sampling_frequencies=[48000, 44100, 32000, 16000], + channel_modes=[ SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE, - SBC_JOINT_STEREO_CHANNEL_MODE + SBC_JOINT_STEREO_CHANNEL_MODE, ], - block_lengths = [4, 8, 12, 16], - subbands = [4, 8], - allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD], - minimum_bitpool_value = 2, - maximum_bitpool_value = 53 - ) + block_lengths=[4, 8, 12, 16], + subbands=[4, 8], + allocation_methods=[ + SBC_LOUDNESS_ALLOCATION_METHOD, + SBC_SNR_ALLOCATION_METHOD, + ], + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), ) @@ -164,21 +171,25 @@ async def test_source_sink_1(): listener = Listener(Listener.create_registrar(two_devices.devices[1])) listener.on('connection', on_avdtp_connection) - connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address) + connection = await two_devices.devices[0].connect( + two_devices.devices[1].random_address + ) client = await Protocol.connect(connection) endpoints = await client.discover_remote_endpoints() - assert(len(endpoints) == 1) + assert len(endpoints) == 1 remote_sink = list(endpoints)[0] - assert(remote_sink.in_use == 0) - assert(remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE) - assert(remote_sink.tsep == AVDTP_TSEP_SNK) + assert remote_sink.in_use == 0 + assert remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE + assert remote_sink.tsep == AVDTP_TSEP_SNK async def generate_packets(packet_count): sequence_number = 0 timestamp = 0 for i in range(packet_count): payload = bytes([sequence_number % 256]) - packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload) + packet = MediaPacket( + 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload + ) packet.timestamp_seconds = timestamp / 44100 timestamp += 10 sequence_number += 1 @@ -192,50 +203,49 @@ async def test_source_sink_1(): source = client.add_source(source_codec_capabilities(), pump) stream = await client.create_stream(source, remote_sink) await stream.start() - assert(stream.state == AVDTP_STREAMING_STATE) - assert(stream.local_endpoint.in_use == 1) - assert(stream.rtp_channel is not None) - assert(sink.in_use == 1) - assert(sink.stream is not None) - assert(sink.stream.state == AVDTP_STREAMING_STATE) + assert stream.state == AVDTP_STREAMING_STATE + assert stream.local_endpoint.in_use == 1 + assert stream.rtp_channel is not None + assert sink.in_use == 1 + assert sink.stream is not None + assert sink.stream.state == AVDTP_STREAMING_STATE await rtp_packets_fully_received await stream.close() - assert(stream.rtp_channel is None) - assert(source.in_use == 0) - assert(source.stream.state == AVDTP_IDLE_STATE) - assert(sink.in_use == 0) - assert(sink.stream.state == AVDTP_IDLE_STATE) + assert stream.rtp_channel is None + assert source.in_use == 0 + assert source.stream.state == AVDTP_IDLE_STATE + assert sink.in_use == 0 + assert sink.stream.state == AVDTP_IDLE_STATE # Send packets manually rtp_packets_fully_received = asyncio.get_running_loop().create_future() rtp_packets_expected = 3 rtp_packets = [] source_packets = [ - MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i])) - for i in range(3) + MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i])) for i in range(3) ] source = client.add_source(source_codec_capabilities(), None) stream = await client.create_stream(source, remote_sink) await stream.start() - assert(stream.state == AVDTP_STREAMING_STATE) - assert(stream.local_endpoint.in_use == 1) - assert(stream.rtp_channel is not None) - assert(sink.in_use == 1) - assert(sink.stream is not None) - assert(sink.stream.state == AVDTP_STREAMING_STATE) + assert stream.state == AVDTP_STREAMING_STATE + assert stream.local_endpoint.in_use == 1 + assert stream.rtp_channel is not None + assert sink.in_use == 1 + assert sink.stream is not None + assert sink.stream.state == AVDTP_STREAMING_STATE stream.send_media_packet(source_packets[0]) stream.send_media_packet(source_packets[1]) stream.send_media_packet(source_packets[2]) await stream.close() - assert(stream.rtp_channel is None) - assert(len(rtp_packets) == 3) - assert(source.in_use == 0) - assert(source.stream.state == AVDTP_IDLE_STATE) - assert(sink.in_use == 0) - assert(sink.stream.state == AVDTP_IDLE_STATE) + assert stream.rtp_channel is None + assert len(rtp_packets) == 3 + assert source.in_use == 0 + assert source.stream.state == AVDTP_IDLE_STATE + assert sink.in_use == 0 + assert sink.stream.state == AVDTP_IDLE_STATE # ----------------------------------------------------------------------------- @@ -246,5 +256,5 @@ async def run_test_self(): # ----------------------------------------------------------------------------- if __name__ == '__main__': - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(run_test_self()) diff --git a/tests/avdtp_test.py b/tests/avdtp_test.py index d14e4faf..1ca5254d 100644 --- a/tests/avdtp_test.py +++ b/tests/avdtp_test.py @@ -28,7 +28,7 @@ from bumble.avdtp import ( Set_Configuration_Command, Set_Configuration_Response, ServiceCapabilities, - MediaCodecCapabilities + MediaCodecCapabilities, ) @@ -37,24 +37,28 @@ def test_messages(): capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), MediaCodecCapabilities( - media_type = AVDTP_AUDIO_MEDIA_TYPE, - media_codec_type = A2DP_SBC_CODEC_TYPE, - media_codec_information = bytes.fromhex('211502fa') + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=bytes.fromhex('211502fa'), ), - ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY) + ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY), ] message = Get_Capabilities_Response(capabilities) - parsed = Message.create(AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload) - assert(message.payload == parsed.payload) + parsed = Message.create( + AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload + ) + assert message.payload == parsed.payload message = Set_Configuration_Command(3, 4, capabilities) parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload) - assert(message.payload == parsed.payload) + assert message.payload == parsed.payload # ----------------------------------------------------------------------------- def test_rtp(): - packet = bytes.fromhex('8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093') + packet = bytes.fromhex( + '8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093' + ) media_packet = MediaPacket.from_bytes(packet) print(media_packet) @@ -62,4 +66,4 @@ def test_rtp(): # ----------------------------------------------------------------------------- if __name__ == '__main__': test_messages() - test_rtp() \ No newline at end of file + test_rtp() diff --git a/tests/core_test.py b/tests/core_test.py index 226d0bfb..ba4ca5d3 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -22,32 +22,35 @@ def test_ad_data(): data = bytes([2, AdvertisingData.TX_POWER_LEVEL, 123]) ad = AdvertisingData.from_bytes(data) ad_bytes = bytes(ad) - assert(data == ad_bytes) - assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None) - assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])) - assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []) - assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123])]) + assert data == ad_bytes + assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None + assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]) + assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [] + assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [ + bytes([123]) + ] data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234]) ad.append(data2) ad_bytes = bytes(ad) - assert(ad_bytes == data + data2) - assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None) - assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])) - assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []) - assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123]), bytes([234])]) + assert ad_bytes == data + data2 + assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None + assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]) + assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [] + assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [ + bytes([123]), + bytes([234]), + ] # ----------------------------------------------------------------------------- def test_get_dict_key_by_value(): - dictionary = { - "A": 1, - "B": 2 - } + dictionary = {"A": 1, "B": 2} assert get_dict_key_by_value(dictionary, 1) == "A" assert get_dict_key_by_value(dictionary, 2) == "B" assert get_dict_key_by_value(dictionary, 3) is None + # ----------------------------------------------------------------------------- if __name__ == '__main__': - test_ad_data() \ No newline at end of file + test_ad_data() diff --git a/tests/device_test.py b/tests/device_test.py index dd23c886..123df29e 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -25,10 +25,23 @@ from bumble.core import BT_BR_EDR_TRANSPORT from bumble.device import Connection, Device from bumble.host import Host from bumble.hci import ( - HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS, - Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet + HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, + HCI_COMMAND_STATUS_PENDING, + HCI_CREATE_CONNECTION_COMMAND, + HCI_SUCCESS, + Address, + HCI_Command_Complete_Event, + HCI_Command_Status_Event, + HCI_Connection_Complete_Event, + HCI_Connection_Request_Event, + HCI_Packet, +) +from bumble.gatt import ( + GATT_GENERIC_ACCESS_SERVICE, + GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, + GATT_DEVICE_NAME_CHARACTERISTIC, + GATT_APPEARANCE_CHARACTERISTIC, ) -from bumble.gatt import GATT_GENERIC_ACCESS_SERVICE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_DEVICE_NAME_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC # ----------------------------------------------------------------------------- # Logging @@ -59,70 +72,90 @@ async def test_device_connect_parallel(): d2.classic_enabled = True # set public addresses - d0.public_address = Address('F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS) - d1.public_address = Address('F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS) - d2.public_address = Address('F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS) + d0.public_address = Address( + 'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS + ) + d1.public_address = Address( + 'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS + ) + d2.public_address = Address( + 'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS + ) def d0_flow(): packet = HCI_Packet.from_bytes((yield)) assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND' assert packet.bd_addr == d1.public_address - d0.host.on_hci_packet(HCI_Command_Status_Event( - status = HCI_COMMAND_STATUS_PENDING, - num_hci_command_packets = 1, - command_opcode = HCI_CREATE_CONNECTION_COMMAND - )) + d0.host.on_hci_packet( + HCI_Command_Status_Event( + status=HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets=1, + command_opcode=HCI_CREATE_CONNECTION_COMMAND, + ) + ) - d1.host.on_hci_packet(HCI_Connection_Request_Event( - bd_addr = d0.public_address, - class_of_device = 0, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE - )) + d1.host.on_hci_packet( + HCI_Connection_Request_Event( + bd_addr=d0.public_address, + class_of_device=0, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + ) + ) packet = HCI_Packet.from_bytes((yield)) assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND' assert packet.bd_addr == d2.public_address - d0.host.on_hci_packet(HCI_Command_Status_Event( - status = HCI_COMMAND_STATUS_PENDING, - num_hci_command_packets = 1, - command_opcode = HCI_CREATE_CONNECTION_COMMAND - )) + d0.host.on_hci_packet( + HCI_Command_Status_Event( + status=HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets=1, + command_opcode=HCI_CREATE_CONNECTION_COMMAND, + ) + ) - d2.host.on_hci_packet(HCI_Connection_Request_Event( - bd_addr = d0.public_address, - class_of_device = 0, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE - )) + d2.host.on_hci_packet( + HCI_Connection_Request_Event( + bd_addr=d0.public_address, + class_of_device=0, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + ) + ) assert (yield) == None - + def d1_flow(): packet = HCI_Packet.from_bytes((yield)) assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' - d1.host.on_hci_packet(HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, - return_parameters = b"\x00" - )) + d1.host.on_hci_packet( + HCI_Command_Complete_Event( + num_hci_command_packets=1, + command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, + return_parameters=b"\x00", + ) + ) - d1.host.on_hci_packet(HCI_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x100, - bd_addr = d0.public_address, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, - encryption_enabled = True, - )) + d1.host.on_hci_packet( + HCI_Connection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=0x100, + bd_addr=d0.public_address, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled=True, + ) + ) - d0.host.on_hci_packet(HCI_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x100, - bd_addr = d1.public_address, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, - encryption_enabled = True, - )) + d0.host.on_hci_packet( + HCI_Connection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=0x100, + bd_addr=d1.public_address, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled=True, + ) + ) assert (yield) == None @@ -130,27 +163,33 @@ async def test_device_connect_parallel(): packet = HCI_Packet.from_bytes((yield)) assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' - d2.host.on_hci_packet(HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, - return_parameters = b"\x00" - )) + d2.host.on_hci_packet( + HCI_Command_Complete_Event( + num_hci_command_packets=1, + command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, + return_parameters=b"\x00", + ) + ) - d2.host.on_hci_packet(HCI_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x101, - bd_addr = d0.public_address, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, - encryption_enabled = True, - )) + d2.host.on_hci_packet( + HCI_Connection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=0x101, + bd_addr=d0.public_address, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled=True, + ) + ) - d0.host.on_hci_packet(HCI_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x101, - bd_addr = d2.public_address, - link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, - encryption_enabled = True, - )) + d0.host.on_hci_packet( + HCI_Connection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=0x101, + bd_addr=d2.public_address, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled=True, + ) + ) assert (yield) == None @@ -158,13 +197,19 @@ async def test_device_connect_parallel(): d1.host.set_packet_sink(Sink(d1_flow())) d2.host.set_packet_sink(Sink(d2_flow())) - [c01, c02, a10, a20, a01] = await asyncio.gather(*[ - asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)), - asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)), - asyncio.create_task(d1.accept(peer_address=d0.public_address)), - asyncio.create_task(d2.accept()), - asyncio.create_task(d0.accept(peer_address=d1.public_address)), - ]) + [c01, c02, a10, a20, a01] = await asyncio.gather( + *[ + asyncio.create_task( + d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT) + ), + asyncio.create_task( + d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT) + ), + asyncio.create_task(d1.accept(peer_address=d0.public_address)), + asyncio.create_task(d2.accept()), + asyncio.create_task(d0.accept(peer_address=d1.public_address)), + ] + ) assert type(c01) == Connection assert type(c02) == Connection @@ -205,5 +250,5 @@ def test_gatt_services_without_gas(): # ----------------------------------------------------------------------------- if __name__ == '__main__': - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(run_test_device()) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 9630fde1..24736237 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -36,7 +36,7 @@ from bumble.gatt import ( UTF8CharacteristicAdapter, Service, Characteristic, - CharacteristicValue + CharacteristicValue, ) from bumble.transport import AsyncPipeSink from bumble.core import UUID @@ -45,7 +45,7 @@ from bumble.att import ( ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_PDU, ATT_Error_Response, - ATT_Read_By_Group_Type_Request + ATT_Read_By_Group_Type_Request, ) @@ -72,20 +72,20 @@ def test_UUID(): assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' u1 = UUID.from_16_bits(0x1234) - b1 = u1.to_bytes(force_128 = True) + b1 = u1.to_bytes(force_128=True) u2 = UUID.from_bytes(b1) assert u1 == u2 - u3 = UUID.from_16_bits(0x180a) + u3 = UUID.from_16_bits(0x180A) assert str(u3) == 'UUID-16:180A (Device Information)' # ----------------------------------------------------------------------------- def test_ATT_Error_Response(): pdu = ATT_Error_Response( - request_opcode_in_error = ATT_EXCHANGE_MTU_REQUEST, - attribute_handle_in_error = 0x0000, - error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR + request_opcode_in_error=ATT_EXCHANGE_MTU_REQUEST, + attribute_handle_in_error=0x0000, + error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) basic_check(pdu) @@ -93,9 +93,9 @@ def test_ATT_Error_Response(): # ----------------------------------------------------------------------------- def test_ATT_Read_By_Group_Type_Request(): pdu = ATT_Read_By_Group_Type_Request( - starting_handle = 0x0001, - ending_handle = 0xFFFF, - attribute_group_type = UUID.from_16_bits(0x2800) + starting_handle=0x0001, + ending_handle=0xFFFF, + attribute_group_type=UUID.from_16_bits(0x2800), ) basic_check(pdu) @@ -110,7 +110,12 @@ async def test_characteristic_encoding(): def decode_value(self, value_bytes): return value_bytes[0] - c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123) + c = Foo( + GATT_BATTERY_LEVEL_CHARACTERISTIC, + Characteristic.READ, + Characteristic.READABLE, + 123, + ) x = c.read_value(None) assert x == bytes([123]) c.write_value(None, bytes([122])) @@ -123,7 +128,7 @@ async def test_characteristic_encoding(): characteristic.handle, characteristic.end_group_handle, characteristic.uuid, - characteristic.properties + characteristic.properties, ) def encode_value(self, value): @@ -138,13 +143,10 @@ async def test_characteristic_encoding(): 'FDB159DB-036C-49E3-B3DB-6325AC750806', Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY, Characteristic.READABLE | Characteristic.WRITEABLE, - bytes([123]) + bytes([123]), ) - service = Service( - '3A657F47-D34F-46B3-B1EC-698E29B6B829', - [characteristic] - ) + service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic]) server.add_service(service) await client.power_on() @@ -237,7 +239,7 @@ async def test_attribute_getters(): characteristic_uuid, Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY, Characteristic.READABLE | Characteristic.WRITEABLE, - bytes([123]) + bytes([123]), ) service_uuid = UUID('3A657F47-D34F-46B3-B1EC-698E29B6B829') @@ -247,22 +249,43 @@ async def test_attribute_getters(): service_attr = server.gatt_server.get_service_attribute(service_uuid) assert service_attr - (char_decl_attr, char_value_attr) = server.gatt_server.get_characteristic_attributes(service_uuid, characteristic_uuid) + ( + char_decl_attr, + char_value_attr, + ) = server.gatt_server.get_characteristic_attributes( + service_uuid, characteristic_uuid + ) assert char_decl_attr and char_value_attr - desc_attr = server.gatt_server.get_descriptor_attribute(service_uuid, characteristic_uuid, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) + desc_attr = server.gatt_server.get_descriptor_attribute( + service_uuid, + characteristic_uuid, + GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, + ) assert desc_attr # assert all handles are in expected order - assert service_attr.handle < char_decl_attr.handle < char_value_attr.handle < desc_attr.handle == service_attr.end_group_handle + assert ( + service_attr.handle + < char_decl_attr.handle + < char_value_attr.handle + < desc_attr.handle + == service_attr.end_group_handle + ) # assert characteristic declarations attribute is followed by characteristic value attribute assert char_decl_attr.handle + 1 == char_value_attr.handle + # ----------------------------------------------------------------------------- def test_CharacteristicAdapter(): # Check that the CharacteristicAdapter base class is transparent v = bytes([1, 2, 3]) - c = Characteristic(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, v) + c = Characteristic( + GATT_BATTERY_LEVEL_CHARACTERISTIC, + Characteristic.READ, + Characteristic.READABLE, + v, + ) a = CharacteristicAdapter(c) value = a.read_value(None) @@ -273,7 +296,9 @@ def test_CharacteristicAdapter(): assert c.value == v # Simple delegated adapter - a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))) + a = DelegatedCharacteristicAdapter( + c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) + ) value = a.read_value(None) assert value == bytes(reversed(v)) @@ -342,7 +367,9 @@ def test_CharacteristicValue(): assert x == b result = [] - c = CharacteristicValue(write=lambda connection, value: result.append((connection, value))) + c = CharacteristicValue( + write=lambda connection, value: result.append((connection, value)) + ) z = object() c.write(z, b) assert result == [(z, b)] @@ -355,23 +382,23 @@ class LinkedDevices: self.link = LocalLink() self.controllers = [ - Controller('C1', link = self.link), - Controller('C2', link = self.link), - Controller('C3', link = self.link) + Controller('C1', link=self.link), + Controller('C2', link=self.link), + Controller('C3', link=self.link), ] self.devices = [ Device( - address = 'F0:F1:F2:F3:F4:F5', - host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) + address='F0:F1:F2:F3:F4:F5', + host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address = 'F1:F2:F3:F4:F5:F6', - host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) + address='F1:F2:F3:F4:F5:F6', + host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), ), Device( - address = 'F2:F3:F4:F5:F6:F7', - host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2])) - ) + address='F2:F3:F4:F5:F6:F7', + host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])), + ), ] self.paired = [None, None, None] @@ -392,7 +419,7 @@ async def test_read_write(): characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', Characteristic.READ | Characteristic.WRITE, - Characteristic.READABLE | Characteristic.WRITEABLE + Characteristic.READABLE | Characteristic.WRITEABLE, ) def on_characteristic1_write(connection, value): @@ -410,15 +437,13 @@ async def test_read_write(): '66DE9057-C848-4ACA-B993-D675644EBB85', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - CharacteristicValue(read=on_characteristic2_read, write=on_characteristic2_write) + CharacteristicValue( + read=on_characteristic2_read, write=on_characteristic2_write + ), ) service1 = Service( - '3A657F47-D34F-46B3-B1EC-698E29B6B829', - [ - characteristic1, - characteristic2 - ] + '3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1, characteristic2] ) server.add_services([service1]) @@ -446,7 +471,9 @@ async def test_read_write(): assert v1 == b assert type(characteristic1._last_value is tuple) assert len(characteristic1._last_value) == 2 - assert str(characteristic1._last_value[0].peer_address) == str(client.random_address) + assert str(characteristic1._last_value[0].peer_address) == str( + client.random_address + ) assert characteristic1._last_value[1] == b bb = bytes([3, 4, 5, 6]) characteristic1.value = bb @@ -457,7 +484,9 @@ async def test_read_write(): await async_barrier() assert type(characteristic2._last_value is tuple) assert len(characteristic2._last_value) == 2 - assert str(characteristic2._last_value[0].peer_address) == str(client.random_address) + assert str(characteristic2._last_value[0].peer_address) == str( + client.random_address + ) assert characteristic2._last_value[1] == b @@ -471,15 +500,10 @@ async def test_read_write2(): 'FDB159DB-036C-49E3-B3DB-6325AC750806', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - value=v + value=v, ) - service1 = Service( - '3A657F47-D34F-46B3-B1EC-698E29B6B829', - [ - characteristic1 - ] - ) + service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1]) server.add_services([service1]) await client.power_on() @@ -520,11 +544,15 @@ async def test_subscribe_notify(): 'FDB159DB-036C-49E3-B3DB-6325AC750806', Characteristic.READ | Characteristic.NOTIFY, Characteristic.READABLE, - bytes([1, 2, 3]) + bytes([1, 2, 3]), ) def on_characteristic1_subscription(connection, notify_enabled, indicate_enabled): - characteristic1._last_subscription = (connection, notify_enabled, indicate_enabled) + characteristic1._last_subscription = ( + connection, + notify_enabled, + indicate_enabled, + ) characteristic1.on('subscription', on_characteristic1_subscription) @@ -532,11 +560,15 @@ async def test_subscribe_notify(): '66DE9057-C848-4ACA-B993-D675644EBB85', Characteristic.READ | Characteristic.INDICATE, Characteristic.READABLE, - bytes([4, 5, 6]) + bytes([4, 5, 6]), ) def on_characteristic2_subscription(connection, notify_enabled, indicate_enabled): - characteristic2._last_subscription = (connection, notify_enabled, indicate_enabled) + characteristic2._last_subscription = ( + connection, + notify_enabled, + indicate_enabled, + ) characteristic2.on('subscription', on_characteristic2_subscription) @@ -544,26 +576,33 @@ async def test_subscribe_notify(): 'AB5E639C-40C1-4238-B9CB-AF41F8B806E4', Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE, Characteristic.READABLE, - bytes([7, 8, 9]) + bytes([7, 8, 9]), ) def on_characteristic3_subscription(connection, notify_enabled, indicate_enabled): - characteristic3._last_subscription = (connection, notify_enabled, indicate_enabled) + characteristic3._last_subscription = ( + connection, + notify_enabled, + indicate_enabled, + ) characteristic3.on('subscription', on_characteristic3_subscription) service1 = Service( '3A657F47-D34F-46B3-B1EC-698E29B6B829', - [ - characteristic1, - characteristic2, - characteristic3 - ] + [characteristic1, characteristic2, characteristic3], ) server.add_services([service1]) - def on_characteristic_subscription(connection, characteristic, notify_enabled, indicate_enabled): - server._last_subscription = (connection, characteristic, notify_enabled, indicate_enabled) + def on_characteristic_subscription( + connection, characteristic, notify_enabled, indicate_enabled + ): + server._last_subscription = ( + connection, + characteristic, + notify_enabled, + indicate_enabled, + ) server.on('characteristic_subscription', on_characteristic_subscription) @@ -630,17 +669,23 @@ async def test_subscribe_notify(): await peer.subscribe(c2, on_c2_update) await async_barrier() - await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2) + await server.notify_subscriber( + characteristic2._last_subscription[0], characteristic2 + ) await async_barrier() assert not c2._called - await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2) + await server.indicate_subscriber( + characteristic2._last_subscription[0], characteristic2 + ) await async_barrier() assert c2._called assert c2._last_update == characteristic2.value c2._called = False await peer.unsubscribe(c2, on_c2_update) - await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2) + await server.indicate_subscriber( + characteristic2._last_subscription[0], characteristic2 + ) await async_barrier() assert not c2._called @@ -666,7 +711,9 @@ async def test_subscribe_notify(): c3.on('update', on_c3_update) await peer.subscribe(c3, on_c3_update_2) await async_barrier() - await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3) + await server.notify_subscriber( + characteristic3._last_subscription[0], characteristic3 + ) await async_barrier() assert c3._called assert c3._last_update == characteristic3.value @@ -681,7 +728,9 @@ async def test_subscribe_notify(): await peer.subscribe(c3, on_c3_update_3, prefer_notify=False) await async_barrier() characteristic3.value = bytes([1, 2, 3]) - await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3) + await server.indicate_subscriber( + characteristic3._last_subscription[0], characteristic3 + ) await async_barrier() assert c3._called assert c3._last_update == characteristic3.value @@ -693,8 +742,12 @@ async def test_subscribe_notify(): c3._called_2 = False c3._called_3 = False await peer.unsubscribe(c3) - await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3) - await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3) + await server.notify_subscriber( + characteristic3._last_subscription[0], characteristic3 + ) + await server.indicate_subscriber( + characteristic3._last_subscription[0], characteristic3 + ) await async_barrier() assert not c3._called assert not c3._called_2 @@ -709,6 +762,7 @@ async def test_mtu_exchange(): d3.gatt_server.max_mtu = 100 d3_connections = [] + @d3.on('connection') def on_d3_connection(connection): d3_connections.append(connection) @@ -745,7 +799,12 @@ def test_char_property_to_string(): # double assert Characteristic.properties_as_string(0x03) == "BROADCAST,READ" - assert Characteristic.properties_as_string(Characteristic.BROADCAST | Characteristic.READ) == "BROADCAST,READ" + assert ( + Characteristic.properties_as_string( + Characteristic.BROADCAST | Characteristic.READ + ) + == "BROADCAST,READ" + ) # ----------------------------------------------------------------------------- @@ -754,8 +813,14 @@ def test_char_property_string_to_type(): assert Characteristic.string_to_properties("BROADCAST") == Characteristic.BROADCAST # double - assert Characteristic.string_to_properties("BROADCAST,READ") == Characteristic.BROADCAST | Characteristic.READ - assert Characteristic.string_to_properties("READ,BROADCAST") == Characteristic.BROADCAST | Characteristic.READ + assert ( + Characteristic.string_to_properties("BROADCAST,READ") + == Characteristic.BROADCAST | Characteristic.READ + ) + assert ( + Characteristic.string_to_properties("READ,BROADCAST") + == Characteristic.BROADCAST | Characteristic.READ + ) # ----------------------------------------------------------------------------- @@ -767,16 +832,15 @@ async def test_server_string(): 'FDB159DB-036C-49E3-B3DB-6325AC750806', Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY, Characteristic.READABLE | Characteristic.WRITEABLE, - bytes([123]) + bytes([123]), ) - service = Service( - '3A657F47-D34F-46B3-B1EC-698E29B6B829', - [characteristic] - ) + service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic]) server.add_service(service) - assert str(server.gatt_server) == """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access)) + assert ( + str(server.gatt_server) + == """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access)) CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ) Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ) CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ) @@ -785,6 +849,8 @@ Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829) CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY) Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY) Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)""" + ) + # ----------------------------------------------------------------------------- async def async_main(): @@ -797,7 +863,7 @@ async def async_main(): # ----------------------------------------------------------------------------- if __name__ == '__main__': - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) test_UUID() test_ATT_Error_Response() test_ATT_Read_By_Group_Type_Request() diff --git a/tests/hci_test.py b/tests/hci_test.py index a5e3a829..3d40628c 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -44,15 +44,15 @@ def test_HCI_Event(): def test_HCI_LE_Connection_Complete_Event(): address = Address('00:11:22:33:44:55') event = HCI_LE_Connection_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 1, - role = 1, - peer_address_type = 1, - peer_address = address, - connection_interval = 3, - peripheral_latency = 4, - supervision_timeout = 5, - central_clock_accuracy = 6 + status=HCI_SUCCESS, + connection_handle=1, + role=1, + peer_address_type=1, + peer_address=address, + connection_interval=3, + peripheral_latency=4, + supervision_timeout=5, + central_clock_accuracy=6, ) basic_check(event) @@ -62,11 +62,13 @@ def test_HCI_LE_Advertising_Report_Event(): address = Address('00:11:22:33:44:55/P') report = HCI_LE_Advertising_Report_Event.Report( HCI_LE_Advertising_Report_Event.Report.FIELDS, - event_type = HCI_LE_Advertising_Report_Event.ADV_IND, - address_type = Address.PUBLIC_DEVICE_ADDRESS, - address = address, - data = bytes.fromhex('0201061106ba5689a6fabfa2bd01467d6e00fbabad08160a181604659b03'), - rssi = 100 + event_type=HCI_LE_Advertising_Report_Event.ADV_IND, + address_type=Address.PUBLIC_DEVICE_ADDRESS, + address=address, + data=bytes.fromhex( + '0201061106ba5689a6fabfa2bd01467d6e00fbabad08160a181604659b03' + ), + rssi=100, ) event = HCI_LE_Advertising_Report_Event([report]) basic_check(event) @@ -75,9 +77,9 @@ def test_HCI_LE_Advertising_Report_Event(): # ----------------------------------------------------------------------------- def test_HCI_LE_Read_Remote_Features_Complete_Event(): event = HCI_LE_Read_Remote_Features_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x007, - le_features = bytes.fromhex('0011223344556677') + status=HCI_SUCCESS, + connection_handle=0x007, + le_features=bytes.fromhex('0011223344556677'), ) basic_check(event) @@ -85,11 +87,11 @@ def test_HCI_LE_Read_Remote_Features_Complete_Event(): # ----------------------------------------------------------------------------- def test_HCI_LE_Connection_Update_Complete_Event(): event = HCI_LE_Connection_Update_Complete_Event( - status = HCI_SUCCESS, - connection_handle = 0x007, - connection_interval = 10, - peripheral_latency = 3, - supervision_timeout = 5 + status=HCI_SUCCESS, + connection_handle=0x007, + connection_interval=10, + peripheral_latency=3, + supervision_timeout=5, ) basic_check(event) @@ -97,8 +99,7 @@ def test_HCI_LE_Connection_Update_Complete_Event(): # ----------------------------------------------------------------------------- def test_HCI_LE_Channel_Selection_Algorithm_Event(): event = HCI_LE_Channel_Selection_Algorithm_Event( - connection_handle = 7, - channel_selection_algorithm = 1 + connection_handle=7, channel_selection_algorithm=1 ) basic_check(event) @@ -107,29 +108,29 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event(): def test_HCI_Command_Complete_Event(): # With a serializable object event = HCI_Command_Complete_Event( - num_hci_command_packets = 34, - command_opcode = HCI_LE_READ_BUFFER_SIZE_COMMAND, - return_parameters = HCI_LE_Read_Buffer_Size_Command.create_return_parameters( - status = 0, - hc_le_acl_data_packet_length = 1234, - hc_total_num_le_acl_data_packets = 56 - ) + num_hci_command_packets=34, + command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND, + return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters( + status=0, + hc_le_acl_data_packet_length=1234, + hc_total_num_le_acl_data_packets=56, + ), ) basic_check(event) # With an arbitrary byte array event = HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = HCI_RESET_COMMAND, - return_parameters = bytes([1, 2, 3, 4]) + num_hci_command_packets=1, + command_opcode=HCI_RESET_COMMAND, + return_parameters=bytes([1, 2, 3, 4]), ) basic_check(event) # With a simple status as a 1-byte array event = HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = HCI_RESET_COMMAND, - return_parameters = bytes([7]) + num_hci_command_packets=1, + command_opcode=HCI_RESET_COMMAND, + return_parameters=bytes([7]), ) basic_check(event) event = HCI_Packet.from_bytes(event.to_bytes()) @@ -137,9 +138,7 @@ def test_HCI_Command_Complete_Event(): # With a simple status as an integer status event = HCI_Command_Complete_Event( - num_hci_command_packets = 1, - command_opcode = HCI_RESET_COMMAND, - return_parameters = 9 + num_hci_command_packets=1, command_opcode=HCI_RESET_COMMAND, return_parameters=9 ) basic_check(event) assert event.return_parameters == 9 @@ -148,19 +147,14 @@ def test_HCI_Command_Complete_Event(): # ----------------------------------------------------------------------------- def test_HCI_Command_Status_Event(): event = HCI_Command_Status_Event( - status = 0, - num_hci_command_packets = 37, - command_opcode = HCI_DISCONNECT_COMMAND + status=0, num_hci_command_packets=37, command_opcode=HCI_DISCONNECT_COMMAND ) basic_check(event) # ----------------------------------------------------------------------------- def test_HCI_Number_Of_Completed_Packets_Event(): - event = HCI_Number_Of_Completed_Packets_Event([ - (1, 2), - (3, 4) - ]) + event = HCI_Number_Of_Completed_Packets_Event([(1, 2), (3, 4)]) basic_check(event) @@ -199,25 +193,20 @@ def test_HCI_Read_Local_Supported_Features_Command(): # ----------------------------------------------------------------------------- def test_HCI_Disconnect_Command(): - command = HCI_Disconnect_Command( - connection_handle = 123, - reason = 0x11 - ) + command = HCI_Disconnect_Command(connection_handle=123, reason=0x11) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_Set_Event_Mask_Command(): - command = HCI_Set_Event_Mask_Command( - event_mask = bytes.fromhex('0011223344556677') - ) + command = HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('0011223344556677')) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Event_Mask_Command(): command = HCI_LE_Set_Event_Mask_Command( - le_event_mask = bytes.fromhex('0011223344556677') + le_event_mask=bytes.fromhex('0011223344556677') ) basic_check(command) @@ -225,7 +214,7 @@ def test_HCI_LE_Set_Event_Mask_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Random_Address_Command(): command = HCI_LE_Set_Random_Address_Command( - random_address = Address('00:11:22:33:44:55') + random_address=Address('00:11:22:33:44:55') ) basic_check(command) @@ -233,14 +222,14 @@ def test_HCI_LE_Set_Random_Address_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Advertising_Parameters_Command(): command = HCI_LE_Set_Advertising_Parameters_Command( - advertising_interval_min = 20, - advertising_interval_max = 30, - advertising_type = HCI_LE_Set_Advertising_Parameters_Command.ADV_NONCONN_IND, - own_address_type = Address.PUBLIC_DEVICE_ADDRESS, - peer_address_type = Address.RANDOM_DEVICE_ADDRESS, - peer_address = Address('00:11:22:33:44:55'), - advertising_channel_map = 0x03, - advertising_filter_policy = 1 + advertising_interval_min=20, + advertising_interval_max=30, + advertising_type=HCI_LE_Set_Advertising_Parameters_Command.ADV_NONCONN_IND, + own_address_type=Address.PUBLIC_DEVICE_ADDRESS, + peer_address_type=Address.RANDOM_DEVICE_ADDRESS, + peer_address=Address('00:11:22:33:44:55'), + advertising_channel_map=0x03, + advertising_filter_policy=1, ) basic_check(command) @@ -248,7 +237,7 @@ def test_HCI_LE_Set_Advertising_Parameters_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Advertising_Data_Command(): command = HCI_LE_Set_Advertising_Data_Command( - advertising_data = bytes.fromhex('AABBCC') + advertising_data=bytes.fromhex('AABBCC') ) basic_check(command) @@ -256,39 +245,36 @@ def test_HCI_LE_Set_Advertising_Data_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Scan_Parameters_Command(): command = HCI_LE_Set_Scan_Parameters_Command( - le_scan_type = 1, - le_scan_interval = 20, - le_scan_window = 10, - own_address_type = 1, - scanning_filter_policy = 0 + le_scan_type=1, + le_scan_interval=20, + le_scan_window=10, + own_address_type=1, + scanning_filter_policy=0, ) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Scan_Enable_Command(): - command = HCI_LE_Set_Scan_Enable_Command( - le_scan_enable = 1, - filter_duplicates = 0 - ) + command = HCI_LE_Set_Scan_Enable_Command(le_scan_enable=1, filter_duplicates=0) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_LE_Create_Connection_Command(): command = HCI_LE_Create_Connection_Command( - le_scan_interval = 4, - le_scan_window = 5, - initiator_filter_policy = 1, - peer_address_type = 1, - peer_address = Address('00:11:22:33:44:55'), - own_address_type = 2, - connection_interval_min = 7, - connection_interval_max = 8, - max_latency = 9, - supervision_timeout = 10, - min_ce_length = 11, - max_ce_length = 12 + le_scan_interval=4, + le_scan_window=5, + initiator_filter_policy=1, + peer_address_type=1, + peer_address=Address('00:11:22:33:44:55'), + own_address_type=2, + connection_interval_min=7, + connection_interval_max=8, + max_latency=9, + supervision_timeout=10, + min_ce_length=11, + max_ce_length=12, ) basic_check(command) @@ -296,19 +282,19 @@ def test_HCI_LE_Create_Connection_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Extended_Create_Connection_Command(): command = HCI_LE_Extended_Create_Connection_Command( - initiator_filter_policy = 0, - own_address_type = 0, - peer_address_type = 1, - peer_address = Address('00:11:22:33:44:55'), - initiating_phys = 3, - scan_intervals = (10, 11), - scan_windows = (12, 13), - connection_interval_mins = (14, 15), - connection_interval_maxs = (16, 17), - max_latencies = (18, 19), - supervision_timeouts = (20, 21), - min_ce_lengths = (100, 101), - max_ce_lengths = (102, 103) + initiator_filter_policy=0, + own_address_type=0, + peer_address_type=1, + peer_address=Address('00:11:22:33:44:55'), + initiating_phys=3, + scan_intervals=(10, 11), + scan_windows=(12, 13), + connection_interval_mins=(14, 15), + connection_interval_maxs=(16, 17), + max_latencies=(18, 19), + supervision_timeouts=(20, 21), + min_ce_lengths=(100, 101), + max_ce_lengths=(102, 103), ) basic_check(command) @@ -316,8 +302,7 @@ def test_HCI_LE_Extended_Create_Connection_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Add_Device_To_Filter_Accept_List_Command(): command = HCI_LE_Add_Device_To_Filter_Accept_List_Command( - address_type = 1, - address = Address('00:11:22:33:44:55') + address_type=1, address=Address('00:11:22:33:44:55') ) basic_check(command) @@ -325,8 +310,7 @@ def test_HCI_LE_Add_Device_To_Filter_Accept_List_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Remove_Device_From_Filter_Accept_List_Command(): command = HCI_LE_Remove_Device_From_Filter_Accept_List_Command( - address_type = 1, - address = Address('00:11:22:33:44:55') + address_type=1, address=Address('00:11:22:33:44:55') ) basic_check(command) @@ -334,32 +318,26 @@ def test_HCI_LE_Remove_Device_From_Filter_Accept_List_Command(): # ----------------------------------------------------------------------------- def test_HCI_LE_Connection_Update_Command(): command = HCI_LE_Connection_Update_Command( - connection_handle = 0x0002, - connection_interval_min = 10, - connection_interval_max = 20, - max_latency = 7, - supervision_timeout = 3, - min_ce_length = 100, - max_ce_length = 200 + connection_handle=0x0002, + connection_interval_min=10, + connection_interval_max=20, + max_latency=7, + supervision_timeout=3, + min_ce_length=100, + max_ce_length=200, ) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_LE_Read_Remote_Features_Command(): - command = HCI_LE_Read_Remote_Features_Command( - connection_handle = 0x0002 - ) + command = HCI_LE_Read_Remote_Features_Command(connection_handle=0x0002) basic_check(command) # ----------------------------------------------------------------------------- def test_HCI_LE_Set_Default_PHY_Command(): - command = HCI_LE_Set_Default_PHY_Command( - all_phys = 0, - tx_phys = 1, - rx_phys = 1 - ) + command = HCI_LE_Set_Default_PHY_Command(all_phys=0, tx_phys=1, rx_phys=1) basic_check(command) @@ -372,10 +350,10 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command(): scan_types=[ HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING, HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING, - HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING + HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING, ], scan_intervals=[1, 2, 3], - scan_windows=[4, 5, 6] + scan_windows=[4, 5, 6], ) basic_check(command) diff --git a/tests/import_test.py b/tests/import_test.py index 4a83ceab..e0b6e3ca 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -38,7 +38,7 @@ def test_import(): sdp, smp, transport, - utils + utils, ) assert att @@ -65,36 +65,47 @@ def test_import(): # ----------------------------------------------------------------------------- def test_app_imports(): from apps.console import main + assert main from apps.controller_info import main + assert main from apps.controllers import main + assert main from apps.gatt_dump import main + assert main from apps.gg_bridge import main + assert main from apps.hci_bridge import main + assert main from apps.pair import main + assert main from apps.scan import main + assert main from apps.show import main + assert main from apps.unbond import main + assert main from apps.usb_probe import main + assert main @@ -103,7 +114,7 @@ def test_profiles_imports(): from bumble.profiles import ( battery_service, device_information_service, - heart_rate_service + heart_rate_service, ) assert battery_service diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 82d53864..6f8e1810 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -27,9 +27,7 @@ from bumble.device import Device from bumble.host import Host from bumble.transport import AsyncPipeSink from bumble.core import ProtocolError -from bumble.l2cap import ( - L2CAP_Connection_Request -) +from bumble.l2cap import L2CAP_Connection_Request # ----------------------------------------------------------------------------- @@ -45,18 +43,18 @@ class TwoDevices: self.link = LocalLink() self.controllers = [ - Controller('C1', link = self.link), - Controller('C2', link = self.link) + Controller('C1', link=self.link), + Controller('C2', link=self.link), ] self.devices = [ Device( - address = 'F0:F1:F2:F3:F4:F5', - host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) + address='F0:F1:F2:F3:F4:F5', + host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address = 'F5:F4:F3:F2:F1:F0', - host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) - ) + address='F5:F4:F3:F2:F1:F0', + host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), + ), ] self.paired = [None, None] @@ -74,8 +72,12 @@ async def setup_connection(): two_devices = TwoDevices() # Attach listeners - two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) - two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) + two_devices.devices[0].on( + 'connection', lambda connection: two_devices.on_connection(0, connection) + ) + two_devices.devices[1].on( + 'connection', lambda connection: two_devices.on_connection(1, connection) + ) # Start await two_devices.devices[0].power_on() @@ -102,19 +104,25 @@ def test_helpers(): psm = L2CAP_Connection_Request.serialize_psm(0x242311) assert psm == bytes([0x11, 0x23, 0x24]) - (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1) + (offset, psm) = L2CAP_Connection_Request.parse_psm( + bytes([0x00, 0x01, 0x00, 0x44]), 1 + ) assert offset == 3 assert psm == 0x01 - (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1) + (offset, psm) = L2CAP_Connection_Request.parse_psm( + bytes([0x00, 0x23, 0x10, 0x44]), 1 + ) assert offset == 3 assert psm == 0x1023 - (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1) + (offset, psm) = L2CAP_Connection_Request.parse_psm( + bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1 + ) assert offset == 4 assert psm == 0x242311 - rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44) + rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44) brq = bytes(rq) srq = L2CAP_Connection_Request.from_bytes(brq) assert srq.psm == rq.psm @@ -147,11 +155,7 @@ async def test_basic_connection(): devices.devices[1].register_l2cap_channel_server(psm, on_coc) l2cap_channel = await devices.connections[0].open_l2cap_channel(psm) - messages = ( - bytes([1, 2, 3]), - bytes([4, 5, 6]), - bytes(10000) - ) + messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000)) for message in messages: l2cap_channel.write(message) await asyncio.sleep(0) @@ -191,18 +195,11 @@ async def transfer_payload(max_credits, mtu, mps): channel.sink = on_data psm = devices.devices[1].register_l2cap_channel_server( - psm = 0, - server = on_coc, - max_credits = max_credits, - mtu = mtu, - mps = mps + psm=0, server=on_coc, max_credits=max_credits, mtu=mtu, mps=mps ) l2cap_channel = await devices.connections[0].open_l2cap_channel(psm) - messages = [ - bytes([1, 2, 3, 4, 5, 6, 7]) * x - for x in (3, 10, 100, 789) - ] + messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)] for message in messages: l2cap_channel.write(message) await asyncio.sleep(0) @@ -233,7 +230,7 @@ async def test_bidirectional_transfer(): client_received = [] server_received = [] - server_channel = None + server_channel = None def on_server_coc(channel): nonlocal server_channel @@ -251,10 +248,7 @@ async def test_bidirectional_transfer(): client_channel = await devices.connections[0].open_l2cap_channel(psm) client_channel.sink = on_client_data - messages = [ - bytes([1, 2, 3, 4, 5, 6, 7]) * x - for x in (3, 10, 100) - ] + messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)] for message in messages: client_channel.write(message) await client_channel.drain() @@ -278,7 +272,8 @@ async def run(): await test_transfer() await test_bidirectional_transfer() + # ----------------------------------------------------------------------------- if __name__ == '__main__': - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(run()) diff --git a/tests/rfcomm_test.py b/tests/rfcomm_test.py index 6f1c44f4..94654623 100644 --- a/tests/rfcomm_test.py +++ b/tests/rfcomm_test.py @@ -30,10 +30,10 @@ def basic_frame_check(x): parsed_bytes = bytes(parsed) if len(serialized) < 500: print('Parsed Bytes:', parsed_bytes.hex()) - assert(parsed_bytes == serialized) + assert parsed_bytes == serialized x_str = str(x) parsed_str = str(parsed) - assert(x_str == parsed_str) + assert x_str == parsed_str # ----------------------------------------------------------------------------- diff --git a/tests/sdp_test.py b/tests/sdp_test.py index d7cf8be1..e117ae81 100644 --- a/tests/sdp_test.py +++ b/tests/sdp_test.py @@ -31,10 +31,10 @@ def basic_check(x): parsed_bytes = bytes(parsed) if len(serialized) < 500: print('Parsed Bytes:', parsed_bytes.hex()) - assert(parsed_bytes == serialized) + assert parsed_bytes == serialized x_str = str(x) parsed_str = str(parsed) - assert(x_str == parsed_str) + assert x_str == parsed_str # ----------------------------------------------------------------------------- @@ -99,19 +99,25 @@ def test_data_elements(): e = DataElement(DataElement.SEQUENCE, [DataElement(DataElement.BOOLEAN, True)]) basic_check(e) - e = DataElement(DataElement.SEQUENCE, [ - DataElement(DataElement.BOOLEAN, True), - DataElement(DataElement.TEXT_STRING, 'hello') - ]) + e = DataElement( + DataElement.SEQUENCE, + [ + DataElement(DataElement.BOOLEAN, True), + DataElement(DataElement.TEXT_STRING, 'hello'), + ], + ) basic_check(e) e = DataElement(DataElement.ALTERNATIVE, [DataElement(DataElement.BOOLEAN, True)]) basic_check(e) - e = DataElement(DataElement.ALTERNATIVE, [ - DataElement(DataElement.BOOLEAN, True), - DataElement(DataElement.TEXT_STRING, 'hello') - ]) + e = DataElement( + DataElement.ALTERNATIVE, + [ + DataElement(DataElement.BOOLEAN, True), + DataElement(DataElement.TEXT_STRING, 'hello'), + ], + ) basic_check(e) e = DataElement(DataElement.URL, 'http://example.com') @@ -133,10 +139,14 @@ def test_data_elements(): e = DataElement.boolean(True) basic_check(e) - e = DataElement.sequence([DataElement.signed_integer(0, 1), DataElement.text_string('hello')]) + e = DataElement.sequence( + [DataElement.signed_integer(0, 1), DataElement.text_string('hello')] + ) basic_check(e) - e = DataElement.alternative([DataElement.signed_integer(0, 1), DataElement.text_string('hello')]) + e = DataElement.alternative( + [DataElement.signed_integer(0, 1), DataElement.text_string('hello')] + ) basic_check(e) e = DataElement.url('http://foobar.com') diff --git a/tests/self_test.py b/tests/self_test.py index 7316b2fe..a33f14c6 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -50,18 +50,18 @@ class TwoDevices: self.link = LocalLink() self.controllers = [ - Controller('C1', link = self.link), - Controller('C2', link = self.link) + Controller('C1', link=self.link), + Controller('C2', link=self.link), ] self.devices = [ Device( - address = 'F0:F1:F2:F3:F4:F5', - host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) + address='F0:F1:F2:F3:F4:F5', + host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address = 'F5:F4:F3:F2:F1:F0', - host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) - ) + address='F5:F4:F3:F2:F1:F0', + host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), + ), ] self.paired = [None, None] @@ -80,8 +80,12 @@ async def test_self_connection(): two_devices = TwoDevices() # Attach listeners - two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) - two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) + two_devices.devices[0].on( + 'connection', lambda connection: two_devices.on_connection(0, connection) + ) + two_devices.devices[1].on( + 'connection', lambda connection: two_devices.on_connection(1, connection) + ) # Start await two_devices.devices[0].power_on() @@ -91,8 +95,8 @@ async def test_self_connection(): await two_devices.devices[0].connect(two_devices.devices[1].random_address) # Check the post conditions - assert(two_devices.connections[0] is not None) - assert(two_devices.connections[1] is not None) + assert two_devices.connections[0] is not None + assert two_devices.connections[1] is not None # ----------------------------------------------------------------------------- @@ -106,25 +110,25 @@ async def test_self_gatt(): '3A143AD7-D4A7-436B-97D6-5B62C315E833', Characteristic.READ, Characteristic.READABLE, - bytes([1, 2, 3]) + bytes([1, 2, 3]), ) c2 = Characteristic( '9557CCE2-DB37-46EB-94C4-50AE5B9CB0F8', Characteristic.READ | Characteristic.WRITE, Characteristic.READABLE | Characteristic.WRITEABLE, - bytes([4, 5, 6]) + bytes([4, 5, 6]), ) c3 = Characteristic( '84FC1A2E-C52D-4A2D-B8C3-8855BAB86638', Characteristic.READ | Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.READABLE | Characteristic.WRITEABLE, - bytes([7, 8, 9]) + bytes([7, 8, 9]), ) c4 = Characteristic( '84FC1A2E-C52D-4A2D-B8C3-8855BAB86638', Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE, Characteristic.READABLE, - bytes([1, 1, 1]) + bytes([1, 1, 1]), ) s1 = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', [c1, c2, c3]) @@ -136,31 +140,33 @@ async def test_self_gatt(): await two_devices.devices[1].power_on() # Connect the two devices - connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address) + connection = await two_devices.devices[0].connect( + two_devices.devices[1].random_address + ) peer = Peer(connection) bogus_uuid = 'A0AA6007-0B48-4BBE-80AC-0DE9AAF541EA' result = await peer.discover_services([bogus_uuid]) - assert(result == []) + assert result == [] services = peer.get_services_by_uuid(bogus_uuid) - assert(len(services) == 0) + assert len(services) == 0 result = await peer.discover_service(s1.uuid) - assert(len(result) == 1) + assert len(result) == 1 services = peer.get_services_by_uuid(s1.uuid) - assert(len(services) == 1) + assert len(services) == 1 s = services[0] - assert(services[0].uuid == s1.uuid) + assert services[0].uuid == s1.uuid result = await peer.discover_characteristics([c1.uuid], s) - assert(len(result) == 1) + assert len(result) == 1 characteristics = peer.get_characteristics_by_uuid(c1.uuid) - assert(len(characteristics) == 1) + assert len(characteristics) == 1 c = characteristics[0] - assert(c.uuid == c1.uuid) + assert c.uuid == c1.uuid result = await peer.read_value(c) - assert(result is not None) - assert(result == c1.value) + assert result is not None + assert result == c1.value # ----------------------------------------------------------------------------- @@ -175,7 +181,7 @@ async def test_self_gatt_long_read(): f'3A143AD7-D4A7-436B-97D6-5B62C315{i:04X}', Characteristic.READ, Characteristic.READABLE, - bytes([x & 255 for x in range(i)]) + bytes([x & 255 for x in range(i)]), ) for i in range(0, 513) ] @@ -188,17 +194,19 @@ async def test_self_gatt_long_read(): await two_devices.devices[1].power_on() # Connect the two devices - connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address) + connection = await two_devices.devices[0].connect( + two_devices.devices[1].random_address + ) peer = Peer(connection) result = await peer.discover_service(service.uuid) - assert(len(result) == 1) + assert len(result) == 1 found_service = result[0] found_characteristics = await found_service.discover_characteristics() - assert(len(found_characteristics) == 513) + assert len(found_characteristics) == 513 for (i, characteristic) in enumerate(found_characteristics): value = await characteristic.read_value() - assert(value == characteristics[i].value) + assert value == characteristics[i].value # ----------------------------------------------------------------------------- @@ -211,28 +219,42 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2): await two_devices.devices[1].power_on() # Attach listeners - two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) - two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) + two_devices.devices[0].on( + 'connection', lambda connection: two_devices.on_connection(0, connection) + ) + two_devices.devices[1].on( + 'connection', lambda connection: two_devices.on_connection(1, connection) + ) # Connect the two devices - connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address) - assert(not connection.is_encrypted) + connection = await two_devices.devices[0].connect( + two_devices.devices[1].random_address + ) + assert not connection.is_encrypted # Attach connection listeners - two_devices.connections[0].on('pairing', lambda keys: two_devices.on_paired(0, keys)) - two_devices.connections[1].on('pairing', lambda keys: two_devices.on_paired(1, keys)) + two_devices.connections[0].on( + 'pairing', lambda keys: two_devices.on_paired(0, keys) + ) + two_devices.connections[1].on( + 'pairing', lambda keys: two_devices.on_paired(1, keys) + ) # Set up the pairing configs if pairing_config1: - two_devices.devices[0].pairing_config_factory = lambda connection: pairing_config1 + two_devices.devices[ + 0 + ].pairing_config_factory = lambda connection: pairing_config1 if pairing_config2: - two_devices.devices[1].pairing_config_factory = lambda connection: pairing_config2 + two_devices.devices[ + 1 + ].pairing_config_factory = lambda connection: pairing_config2 # Pair await two_devices.devices[0].pair(connection) - assert(connection.is_encrypted) - assert(two_devices.paired[0] is not None) - assert(two_devices.paired[1] is not None) + assert connection.is_encrypted + assert two_devices.paired[0] is not None + assert two_devices.paired[1] is not None # ----------------------------------------------------------------------------- @@ -241,22 +263,32 @@ IO_CAP = [ PairingDelegate.KEYBOARD_INPUT_ONLY, PairingDelegate.DISPLAY_OUTPUT_ONLY, PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, - PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT + PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, ] SC = [False, True] MITM = [False, True] # Key distribution is a 4-bit bitmask KEY_DIST = range(16) + @pytest.mark.asyncio -@pytest.mark.parametrize('io_cap, sc, mitm, key_dist', - itertools.product(IO_CAP, SC, MITM, KEY_DIST) +@pytest.mark.parametrize( + 'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST) ) async def test_self_smp(io_cap, sc, mitm, key_dist): class Delegate(PairingDelegate): - def __init__(self, name, io_capability, local_initiator_key_distribution, local_responder_key_distribution): - super().__init__(io_capability, local_initiator_key_distribution, - local_responder_key_distribution) + def __init__( + self, + name, + io_capability, + local_initiator_key_distribution, + local_responder_key_distribution, + ): + super().__init__( + io_capability, + local_initiator_key_distribution, + local_responder_key_distribution, + ) self.name = name self.reset() @@ -279,7 +311,10 @@ async def test_self_smp(io_cap, sc, mitm, key_dist): logger.warn(f'[{self.name}] no peer delegate') return 0 else: - if self.peer_delegate.io_capability == PairingDelegate.KEYBOARD_INPUT_ONLY: + if ( + self.peer_delegate.io_capability + == PairingDelegate.KEYBOARD_INPUT_ONLY + ): peer_number = 6789 else: logger.debug(f'[{self.name}] waiting for peer number') @@ -301,7 +336,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist): for pairing_config1 in pairing_config_sets[0][1]: for pairing_config2 in pairing_config_sets[1][1]: - logger.info(f'########## self_smp with {pairing_config1} and {pairing_config2}') + logger.info( + f'########## self_smp with {pairing_config1} and {pairing_config2}' + ) if pairing_config1: pairing_config1.delegate.reset() if pairing_config2: @@ -313,7 +350,6 @@ async def test_self_smp(io_cap, sc, mitm, key_dist): await _test_self_smp_with_configs(pairing_config1, pairing_config2) - # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_self_smp_reject(): @@ -324,15 +360,15 @@ async def test_self_smp_reject(): async def accept(self): return False - rejecting_pairing_config = PairingConfig(delegate = RejectingDelegate()) + rejecting_pairing_config = PairingConfig(delegate=RejectingDelegate()) paired = False try: await _test_self_smp_with_configs(None, rejecting_pairing_config) paired = True except ProtocolError as error: - assert(error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR) + assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR - assert(not paired) + assert not paired # ----------------------------------------------------------------------------- @@ -345,15 +381,17 @@ async def test_self_smp_wrong_pin(): async def compare_numbers(self, number, digits): return False - wrong_pin_pairing_config = PairingConfig(delegate = WrongPinDelegate()) + wrong_pin_pairing_config = PairingConfig(delegate=WrongPinDelegate()) paired = False try: - await _test_self_smp_with_configs(wrong_pin_pairing_config, wrong_pin_pairing_config) + await _test_self_smp_with_configs( + wrong_pin_pairing_config, wrong_pin_pairing_config + ) paired = True except ProtocolError as error: - assert(error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR) + assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR - assert(not paired) + assert not paired # ----------------------------------------------------------------------------- @@ -365,7 +403,8 @@ async def run_test_self(): await test_self_smp_reject() await test_self_smp_wrong_pin() + # ----------------------------------------------------------------------------- if __name__ == '__main__': - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) asyncio.run(run_test_self()) diff --git a/tests/smp_test.py b/tests/smp_test.py index 9120c477..bc63c525 100644 --- a/tests/smp_test.py +++ b/tests/smp_test.py @@ -29,39 +29,41 @@ def test_ecc(): x = key.x y = key.y - assert(len(x) == 32) - assert(len(y) == 32) + assert len(x) == 32 + assert len(y) == 32 # Test DH with test vectors from the spec - private_A = '3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd' - private_B = '55188b3d 32f6bb9a 900afcfb eed4e72a 59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' - public_A_x = '20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6' - public_A_y = 'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b' - public_B_x = '1ea1f0f0 1faf1d96 09592284 f19e4c00 47b58afd 8615a69f 559077b2 2faaa190' - public_B_y = '4c55f33e 429dad37 7356703a 9ab85160 472d1130 e28e3676 5f89aff9 15b1214a' - dhkey = 'ec0234a3 57c8ad05 341010a6 0a397d9b 99796b13 b4f866f1 868d34f3 73bfa698' + private_A = ( + '3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd' + ) + private_B = ( + '55188b3d 32f6bb9a 900afcfb eed4e72a 59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' + ) + public_A_x = ( + '20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6' + ) + public_A_y = ( + 'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b' + ) + public_B_x = ( + '1ea1f0f0 1faf1d96 09592284 f19e4c00 47b58afd 8615a69f 559077b2 2faaa190' + ) + public_B_y = ( + '4c55f33e 429dad37 7356703a 9ab85160 472d1130 e28e3676 5f89aff9 15b1214a' + ) + dhkey = 'ec0234a3 57c8ad05 341010a6 0a397d9b 99796b13 b4f866f1 868d34f3 73bfa698' key_a = EccKey.from_private_key_bytes( - bytes.fromhex(private_A), - bytes.fromhex(public_A_x), - bytes.fromhex(public_A_y) + bytes.fromhex(private_A), bytes.fromhex(public_A_x), bytes.fromhex(public_A_y) ) - shared_key = key_a.dh( - bytes.fromhex(public_B_x), - bytes.fromhex(public_B_y) - ) - assert(shared_key == bytes.fromhex(dhkey)) + shared_key = key_a.dh(bytes.fromhex(public_B_x), bytes.fromhex(public_B_y)) + assert shared_key == bytes.fromhex(dhkey) key_b = EccKey.from_private_key_bytes( - bytes.fromhex(private_B), - bytes.fromhex(public_B_x), - bytes.fromhex(public_B_y) + bytes.fromhex(private_B), bytes.fromhex(public_B_x), bytes.fromhex(public_B_y) ) - shared_key = key_b.dh( - bytes.fromhex(public_A_x), - bytes.fromhex(public_A_y) - ) - assert(shared_key == bytes.fromhex(dhkey)) + shared_key = key_b.dh(bytes.fromhex(public_A_x), bytes.fromhex(public_A_y)) + assert shared_key == bytes.fromhex(dhkey) # ----------------------------------------------------------------------------- @@ -75,7 +77,7 @@ def test_c1(): rat = 0 ra = reversed_hex('B1B2B3B4B5B6') result = c1(k, r, preq, pres, iat, rat, ia, ra) - assert(result == reversed_hex('1e1e3fef878988ead2a74dc5bef13b86')) + assert result == reversed_hex('1e1e3fef878988ead2a74dc5bef13b86') # ----------------------------------------------------------------------------- @@ -84,7 +86,7 @@ def test_s1(): r1 = reversed_hex('000F0E0D0C0B0A091122334455667788') r2 = reversed_hex('010203040506070899AABBCCDDEEFF00') result = s1(k, r1, r2) - assert(result == reversed_hex('9a1fe1f0e8b0f49b5b4216ae796da062')) + assert result == reversed_hex('9a1fe1f0e8b0f49b5b4216ae796da062') # ----------------------------------------------------------------------------- @@ -92,59 +94,77 @@ def test_aes_cmac(): m = b'' k = bytes.fromhex('2b7e1516 28aed2a6 abf71588 09cf4f3c') cmac = aes_cmac(m, k) - assert(cmac == bytes.fromhex('bb1d6929 e9593728 7fa37d12 9b756746')) + assert cmac == bytes.fromhex('bb1d6929 e9593728 7fa37d12 9b756746') m = bytes.fromhex('6bc1bee2 2e409f96 e93d7e11 7393172a') cmac = aes_cmac(m, k) - assert(cmac == bytes.fromhex('070a16b4 6b4d4144 f79bdd9d d04a287c')) + assert cmac == bytes.fromhex('070a16b4 6b4d4144 f79bdd9d d04a287c') m = bytes.fromhex( - '6bc1bee2 2e409f96 e93d7e11 7393172a' + - 'ae2d8a57 1e03ac9c 9eb76fac 45af8e51' + - '30c81c46 a35ce411' + '6bc1bee2 2e409f96 e93d7e11 7393172a' + + 'ae2d8a57 1e03ac9c 9eb76fac 45af8e51' + + '30c81c46 a35ce411' ) cmac = aes_cmac(m, k) - assert(cmac == bytes.fromhex('dfa66747 de9ae630 30ca3261 1497c827')) + assert cmac == bytes.fromhex('dfa66747 de9ae630 30ca3261 1497c827') m = bytes.fromhex( - '6bc1bee2 2e409f96 e93d7e11 7393172a' + - 'ae2d8a57 1e03ac9c 9eb76fac 45af8e51' + - '30c81c46 a35ce411 e5fbc119 1a0a52ef' + - 'f69f2445 df4f9b17 ad2b417b e66c3710' + '6bc1bee2 2e409f96 e93d7e11 7393172a' + + 'ae2d8a57 1e03ac9c 9eb76fac 45af8e51' + + '30c81c46 a35ce411 e5fbc119 1a0a52ef' + + 'f69f2445 df4f9b17 ad2b417b e66c3710' ) cmac = aes_cmac(m, k) - assert(cmac == bytes.fromhex('51f0bebf 7e3b9d92 fc497417 79363cfe')) + assert cmac == bytes.fromhex('51f0bebf 7e3b9d92 fc497417 79363cfe') # ----------------------------------------------------------------------------- def test_f4(): - u = bytes(reversed(bytes.fromhex( - '20b003d2 f297be2c 5e2c83a7 e9f9a5b9' + - 'eff49111 acf4fddb cc030148 0e359de6' - ))) - v = bytes(reversed(bytes.fromhex( - '55188b3d 32f6bb9a 900afcfb eed4e72a' + - '59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' - ))) + u = bytes( + reversed( + bytes.fromhex( + '20b003d2 f297be2c 5e2c83a7 e9f9a5b9' + + 'eff49111 acf4fddb cc030148 0e359de6' + ) + ) + ) + v = bytes( + reversed( + bytes.fromhex( + '55188b3d 32f6bb9a 900afcfb eed4e72a' + + '59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' + ) + ) + ) x = bytes(reversed(bytes.fromhex('d5cb8454 d177733e ffffb2ec 712baeab'))) z = bytes([0]) value = f4(u, v, x, z) - assert(bytes(reversed(value)) == bytes.fromhex('f2c916f1 07a9bd1c f1eda1be a974872d')) + assert bytes(reversed(value)) == bytes.fromhex( + 'f2c916f1 07a9bd1c f1eda1be a974872d' + ) # ----------------------------------------------------------------------------- def test_f5(): - w = bytes(reversed(bytes.fromhex( - 'ec0234a3 57c8ad05 341010a6 0a397d9b' + - '99796b13 b4f866f1 868d34f3 73bfa698' - ))) + w = bytes( + reversed( + bytes.fromhex( + 'ec0234a3 57c8ad05 341010a6 0a397d9b' + + '99796b13 b4f866f1 868d34f3 73bfa698' + ) + ) + ) n1 = bytes(reversed(bytes.fromhex('d5cb8454 d177733e ffffb2ec 712baeab'))) n2 = bytes(reversed(bytes.fromhex('a6e8e7cc 25a75f6e 216583f7 ff3dc4cf'))) a1 = bytes(reversed(bytes.fromhex('00561237 37bfce'))) a2 = bytes(reversed(bytes.fromhex('00a71370 2dcfc1'))) value = f5(w, n1, n2, a1, a2) - assert(bytes(reversed(value[0])) == bytes.fromhex('2965f176 a1084a02 fd3f6a20 ce636e20')) - assert(bytes(reversed(value[1])) == bytes.fromhex('69867911 69d7cd23 980522b5 94750a38')) + assert bytes(reversed(value[0])) == bytes.fromhex( + '2965f176 a1084a02 fd3f6a20 ce636e20' + ) + assert bytes(reversed(value[1])) == bytes.fromhex( + '69867911 69d7cd23 980522b5 94750a38' + ) # ----------------------------------------------------------------------------- @@ -157,37 +177,47 @@ def test_f6(): a1 = bytes(reversed(bytes.fromhex('00561237 37bfce'))) a2 = bytes(reversed(bytes.fromhex('00a71370 2dcfc1'))) value = f6(mac_key, n1, n2, r, io_cap, a1, a2) - assert(bytes(reversed(value)) == bytes.fromhex('e3c47398 9cd0e8c5 d26c0b09 da958f61')) + assert bytes(reversed(value)) == bytes.fromhex( + 'e3c47398 9cd0e8c5 d26c0b09 da958f61' + ) # ----------------------------------------------------------------------------- def test_g2(): - u = bytes(reversed(bytes.fromhex( - '20b003d2 f297be2c 5e2c83a7 e9f9a5b9' + - 'eff49111 acf4fddb cc030148 0e359de6' - ))) - v = bytes(reversed(bytes.fromhex( - '55188b3d 32f6bb9a 900afcfb eed4e72a' + - '59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' - ))) + u = bytes( + reversed( + bytes.fromhex( + '20b003d2 f297be2c 5e2c83a7 e9f9a5b9' + + 'eff49111 acf4fddb cc030148 0e359de6' + ) + ) + ) + v = bytes( + reversed( + bytes.fromhex( + '55188b3d 32f6bb9a 900afcfb eed4e72a' + + '59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd' + ) + ) + ) x = bytes(reversed(bytes.fromhex('d5cb8454 d177733e ffffb2ec 712baeab'))) y = bytes(reversed(bytes.fromhex('a6e8e7cc 25a75f6e 216583f7 ff3dc4cf'))) value = g2(u, v, x, y) - assert(value == 0x2f9ed5ba) + assert value == 0x2F9ED5BA # ----------------------------------------------------------------------------- def test_h6(): KEY = bytes.fromhex('ec0234a3 57c8ad05 341010a6 0a397d9b') KEY_ID = bytes.fromhex('6c656272') - assert(h6(KEY, KEY_ID) == bytes.fromhex('2d9ae102 e76dc91c e8d3a9e2 80b16399')) + assert h6(KEY, KEY_ID) == bytes.fromhex('2d9ae102 e76dc91c e8d3a9e2 80b16399') # ----------------------------------------------------------------------------- def test_h7(): KEY = bytes.fromhex('ec0234a3 57c8ad05 341010a6 0a397d9b') SALT = bytes.fromhex('00000000 00000000 00000000 746D7031') - assert(h7(SALT, KEY) == bytes.fromhex('fb173597 c6a3c0ec d2998c2a 75a57011')) + assert h7(SALT, KEY) == bytes.fromhex('fb173597 c6a3c0ec d2998c2a 75a57011') # ----------------------------------------------------------------------------- @@ -196,7 +226,7 @@ def test_ah(): prand = bytes(reversed(bytes.fromhex('708194'))) value = ah(irk, prand) expected = bytes(reversed(bytes.fromhex('0dfbaa'))) - assert(value == expected) + assert value == expected # ----------------------------------------------------------------------------- diff --git a/tests/transport_test.py b/tests/transport_test.py index 30053452..c0069a06 100644 --- a/tests/transport_test.py +++ b/tests/transport_test.py @@ -37,7 +37,9 @@ def test_parser(): parser2 = PacketParser(sink2) for parser in [parser1, parser2]: - with open(os.path.join(os.path.dirname(__file__), 'hci_data_001.bin'), 'rb') as input: + with open( + os.path.join(os.path.dirname(__file__), 'hci_data_001.bin'), 'rb' + ) as input: while True: n = random.randint(1, 9) data = input.read(n) @@ -45,7 +47,7 @@ def test_parser(): break parser.feed_data(data) - assert(sink1.packets == sink2.packets) + assert sink1.packets == sink2.packets # ----------------------------------------------------------------------------- @@ -60,15 +62,15 @@ def test_parser_extensions(): except ValueError: exception_thrown = True - assert(exception_thrown) + assert exception_thrown # Now add a custom info parser.extended_packet_info[0x77] = (1, 1, 'B') parser.reset() parser.feed_data(bytes([0x77, 0x00, 0x02, 0x01, 0x02])) - assert(len(sink.packets) == 1) + assert len(sink.packets) == 1 # ----------------------------------------------------------------------------- test_parser() -test_parser_extensions() \ No newline at end of file +test_parser_extensions() diff --git a/web/scanner.py b/web/scanner.py index e734dbf2..1f9d0e8c 100644 --- a/web/scanner.py +++ b/web/scanner.py @@ -23,7 +23,9 @@ from bumble.transport import PacketParser class ScannerListener(Device.Listener): def on_advertisement(self, advertisement): address_type_string = ('P', 'R', 'PI', 'RI')[advertisement.address.address_type] - print(f'>>> {advertisement.address} [{address_type_string}]: RSSI={advertisement.rssi}, {advertisement.ad_data}') + print( + f'>>> {advertisement.address} [{address_type_string}]: RSSI={advertisement.rssi}, {advertisement.ad_data}' + ) class HciSource: