Add runtime switch for filtering by address.

* scan on [filter pattern]
* filter address <filter pattern>
This commit is contained in:
Alan Rosenthal
2022-10-20 14:47:14 -04:00
parent eb8556ccf6
commit 3d79d7def5

View File

@@ -27,6 +27,7 @@ import logging
import click
from collections import OrderedDict
import colors
import re
from bumble.core import UUID, AdvertisingData, TimeoutError, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
@@ -166,6 +167,9 @@ class ConsoleApp:
'attributes': None,
'log': None
},
'filter': {
'address': None,
},
'connect': LiveCompleter(self.known_addresses),
'update-parameters': None,
'encrypt': None,
@@ -455,6 +459,15 @@ class ConsoleApp:
else:
await self.device.start_scanning()
elif params[0] == 'on':
if len(params) == 2:
if not params[1].startswith("filter="):
self.show_error('invalid syntax', 'expected address filter=key1:value1,key2:value,... available filters: address')
# regex: (word):(any char except ,)
matches = re.findall(r"(\w+):([^,]+)", params[1])
for match in matches:
if match[0] == "address":
self.device.listener.address_filter = match[1]
await self.device.start_scanning()
self.top_tab = 'scan'
elif params[0] == 'off':
@@ -708,6 +721,12 @@ class ConsoleApp:
async def do_quit(self, params):
self.ui.exit()
async def do_filter(self, params):
if params[0] == "address":
if len(params) != 2:
self.show_error('invalid syntax', 'expected filter address <pattern>')
return
self.device.listener.address_filter = params[1]
# -----------------------------------------------------------------------------
# Device and Connection Listener
@@ -716,6 +735,26 @@ class DeviceListener(Device.Listener, Connection.Listener):
def __init__(self, app):
self.app = app
self.scan_results = OrderedDict()
self.address_filter = None
@property
def address_filter(self):
return self._address_filter
@address_filter.setter
def address_filter(self, filter_addr):
if filter_addr is None:
self._address_filter = re.compile(r".*")
else:
self._address_filter = re.compile(filter_addr)
self.scan_results = OrderedDict(filter(lambda x: self.filter_address_match(x), self.scan_results))
self.app.show_scan_results(self.scan_results)
def filter_address_match(self, address):
"""
Returns true if an address matches the filter
"""
return bool(self.address_filter.match(address))
@AsyncRunner.run_in_task()
async def on_connection(self, connection):
@@ -745,6 +784,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
self.app.append_to_output(f'connection data length change: {self.app.connected_peer.connection.data_length}')
def on_advertisement(self, advertisement):
if not self.filter_address_match(str(advertisement.address)):
return
entry_key = f'{advertisement.address}/{advertisement.address.address_type}'
entry = self.scan_results.get(entry_key)
if entry:
@@ -802,6 +844,7 @@ class LogHandler(logging.Handler):
def __init__(self, app):
super().__init__()
self.app = app
self.setFormatter("[%(asctime)s][%(pathname)s:%(lineno)d][%(levelname)s] %(message)s")
def emit(self, record):
message = self.format(record)
@@ -826,6 +869,7 @@ def main(device_config, transport):
# logging.basicConfig(level = 'FATAL')
# logging.basicConfig(level = 'DEBUG')
root_logger = logging.getLogger()
root_logger.addHandler(LogHandler(app))
root_logger.setLevel(logging.DEBUG)